git_stack/git/
repo.rs

1use bstr::ByteSlice;
2use itertools::Itertools;
3
/// Error type returned by repository operations (an alias for [`git2::Error`]).
pub type Error = git2::Error;
/// `Result` alias whose error type defaults to [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;
6
/// Repository operations needed by git-stack.
///
/// Implemented in this file by [`GitRepo`], which is backed by a `git2::Repository`.
pub trait Repo {
    /// Filesystem path of the repository, if it has one.
    fn path(&self) -> Option<&std::path::Path>;
    /// Name of the current user, if known.
    fn user(&self) -> Option<std::rc::Rc<str>>;
    /// Name of the remote to push to.
    fn push_remote(&self) -> &str;
    /// Name of the remote to pull from.
    fn pull_remote(&self) -> &str;

    /// Whether the working tree has changes.
    fn is_dirty(&self) -> bool;
    /// Common ancestor of `one` and `two`, or `None` if none can be found.
    fn merge_base(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid>;

    /// Looks up a commit by id; `None` if it cannot be found.
    fn find_commit(&self, id: git2::Oid) -> Option<std::rc::Rc<Commit>>;
    /// The commit `HEAD` currently points at.
    fn head_commit(&self) -> std::rc::Rc<Commit>;
    /// The branch `HEAD` is on, or `None` when no branch applies.
    fn head_branch(&self) -> Option<Branch>;
    /// Resolves a revision string to a commit; `None` if it does not resolve.
    fn resolve(&self, revspec: &str) -> Option<std::rc::Rc<Commit>>;
    /// Ids of the direct parents of `head_id`.
    fn parent_ids(&self, head_id: git2::Oid) -> Result<Vec<git2::Oid>>;
    /// Number of commits between `base_id` and `head_id`; `None` when it cannot be computed.
    fn commit_count(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize>;
    /// Commit ids from `head_bound` down to `base_bound`, head first.
    ///
    /// Implementations in this file panic if `head_bound` is `Unbounded`.
    fn commit_range(
        &self,
        base_bound: std::ops::Bound<&git2::Oid>,
        head_bound: std::ops::Bound<&git2::Oid>,
    ) -> Result<Vec<git2::Oid>>;
    /// Whether `haystack_id` already contains `needle_id`.
    ///
    /// How "contains" is judged (by content or only by id) is implementation-specific.
    fn contains_commit(&self, haystack_id: git2::Oid, needle_id: git2::Oid) -> Result<bool>;
    /// Applies `cherry_id` on top of `head_id`, returning the id of the new commit.
    fn cherry_pick(&mut self, head_id: git2::Oid, cherry_id: git2::Oid) -> Result<git2::Oid>;
    /// Creates a copy of `head_oid` with message `msg`, returning the new commit id.
    fn reword(&mut self, head_oid: git2::Oid, msg: &str) -> Result<git2::Oid>;
    /// Squashes `head_id` into `into_id`, returning the resulting commit id.
    fn squash(&mut self, head_id: git2::Oid, into_id: git2::Oid) -> Result<git2::Oid>;

    /// Stashes working-tree changes with an optional message; returns the stash id.
    fn stash_push(&mut self, message: Option<&str>) -> Result<git2::Oid>;
    /// Pops the stash entry identified by `stash_id`.
    fn stash_pop(&mut self, stash_id: git2::Oid) -> Result<()>;

    /// Points branch `name` at `id`, creating it if needed.
    fn branch(&mut self, name: &str, id: git2::Oid) -> Result<()>;
    /// Deletes branch `name`.
    fn delete_branch(&mut self, name: &str) -> Result<()>;
    /// Looks up a local branch by name.
    fn find_local_branch(&self, name: &str) -> Option<Branch>;
    /// Looks up the remote-tracking branch `remote/name`.
    fn find_remote_branch(&self, remote: &str, name: &str) -> Option<Branch>;
    /// All local branches that could be loaded.
    fn local_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_>;
    /// All remote-tracking branches that could be loaded.
    fn remote_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_>;
    /// Detaches `HEAD` at its current commit.
    fn detach(&mut self) -> Result<()>;
    /// Checks out local branch `name`.
    fn switch_branch(&mut self, name: &str) -> Result<()>;
    /// Checks out commit `id` with a detached `HEAD`.
    fn switch_commit(&mut self, id: git2::Oid) -> Result<()>;
}
45
/// A branch and the commit it points at.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Branch {
    /// Remote name for remote-tracking branches; `None` for local branches.
    pub remote: Option<String>,
    /// Branch name without any `remote/` prefix.
    pub name: String,
    /// Commit the branch points at.
    pub id: git2::Oid,
}
52
53impl Branch {
54    pub fn local_name(&self) -> Option<&str> {
55        self.remote.is_none().then_some(self.name.as_str())
56    }
57}
58
59impl std::fmt::Display for Branch {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        if let Some(remote) = self.remote.as_deref() {
62            write!(f, "{}/{}", remote, self.name.as_str())
63        } else {
64            write!(f, "{}", self.name.as_str())
65        }
66    }
67}
68
/// Metadata about a single commit, as used by git-stack.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Commit {
    /// Id of the commit.
    pub id: git2::Oid,
    /// Id of the commit's tree.
    pub tree_id: git2::Oid,
    /// Commit summary (`GitRepo` loads this from `git2::Commit::summary_bytes`).
    pub summary: bstr::BString,
    /// Commit time; `GitRepo` clamps negative (pre-epoch) seconds to the epoch.
    pub time: std::time::SystemTime,
    /// Author name, if present (interned by `GitRepo`).
    pub author: Option<std::rc::Rc<str>>,
    /// Committer name, if present.
    pub committer: Option<std::rc::Rc<str>>,
}
78
79impl Commit {
80    pub fn fixup_summary(&self) -> Option<&bstr::BStr> {
81        self.summary
82            .strip_prefix(b"fixup! ")
83            .map(ByteSlice::as_bstr)
84    }
85
86    pub fn wip_summary(&self) -> Option<&bstr::BStr> {
87        // Gitlab MRs only: b"[Draft]", b"(Draft)",
88        static WIP_PREFIXES: &[&[u8]] = &[
89            b"WIP:", b"draft:", b"Draft:", // Gitlab commits
90            b"wip ", b"WIP ", // Less formal
91        ];
92
93        if self.summary == b"WIP".as_bstr() || self.summary == b"wip".as_bstr() {
94            // Very informal
95            Some(b"".as_bstr())
96        } else {
97            WIP_PREFIXES.iter().find_map(|prefix| {
98                self.summary
99                    .strip_prefix(*prefix)
100                    .map(ByteSlice::trim)
101                    .map(ByteSlice::as_bstr)
102            })
103        }
104    }
105
106    pub fn revert_summary(&self) -> Option<&bstr::BStr> {
107        self.summary
108            .strip_prefix(b"Revert ")
109            .and_then(|s| s.strip_suffix(b"\""))
110            .map(ByteSlice::as_bstr)
111    }
112}
113
/// [`Repo`] implementation backed by a `git2::Repository`, with per-instance
/// caches for commit lookups, merge bases and commit counts.
pub struct GitRepo {
    repo: git2::Repository,
    /// Signer passed to commit-creating operations; `None` passes no signer.
    sign: Option<git2_ext::ops::UserSign>,
    /// Explicit push remote; `None` falls back to `"origin"`.
    push_remote: Option<String>,
    /// Explicit pull remote; `None` falls back to `"origin"`.
    pull_remote: Option<String>,
    /// Commits already loaded by `find_commit`, keyed by id.
    commits: std::cell::RefCell<std::collections::HashMap<git2::Oid, std::rc::Rc<Commit>>>,
    /// Interned name strings shared between commits via `Rc`.
    interned_strings: std::cell::RefCell<std::collections::HashSet<std::rc::Rc<str>>>,
    /// Memoized `merge_base` results keyed by the (smaller, larger) id pair.
    bases: std::cell::RefCell<std::collections::HashMap<(git2::Oid, git2::Oid), Option<git2::Oid>>>,
    /// Memoized `commit_count` results keyed by (base, head).
    counts: std::cell::RefCell<std::collections::HashMap<(git2::Oid, git2::Oid), Option<usize>>>,
}
124
125impl GitRepo {
126    pub fn new(repo: git2::Repository) -> Self {
127        Self {
128            repo,
129            sign: None,
130            push_remote: None,
131            pull_remote: None,
132            commits: Default::default(),
133            interned_strings: Default::default(),
134            bases: Default::default(),
135            counts: Default::default(),
136        }
137    }
138
139    pub fn set_sign(&mut self, yes: bool) -> Result<(), git2::Error> {
140        if yes {
141            let config = self.repo.config()?;
142            let sign = git2_ext::ops::UserSign::from_config(&self.repo, &config)?;
143            self.sign = Some(sign);
144        } else {
145            self.sign = None;
146        }
147        Ok(())
148    }
149
150    pub fn set_push_remote(&mut self, remote: &str) {
151        self.push_remote = Some(remote.to_owned());
152    }
153
154    pub fn set_pull_remote(&mut self, remote: &str) {
155        self.pull_remote = Some(remote.to_owned());
156    }
157
158    pub fn push_remote(&self) -> &str {
159        self.push_remote.as_deref().unwrap_or("origin")
160    }
161
162    pub fn pull_remote(&self) -> &str {
163        self.pull_remote.as_deref().unwrap_or("origin")
164    }
165
166    pub fn raw(&self) -> &git2::Repository {
167        &self.repo
168    }
169
170    pub fn user(&self) -> Option<std::rc::Rc<str>> {
171        self.repo
172            .signature()
173            .ok()
174            .and_then(|s| s.name().map(|n| self.intern_string(n)))
175    }
176
177    pub fn is_dirty(&self) -> bool {
178        let status = self
179            .repo
180            .statuses(Some(git2::StatusOptions::new().include_ignored(false)))
181            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
182        if status.is_empty() {
183            false
184        } else {
185            log::trace!(
186                "Repository is dirty: {}",
187                status
188                    .iter()
189                    .filter_map(|s| s.path().map(|s| s.to_owned()))
190                    .join(", ")
191            );
192            true
193        }
194    }
195
196    pub fn merge_base(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid> {
197        if one == two {
198            return Some(one);
199        }
200
201        let (smaller, larger) = if one < two { (one, two) } else { (two, one) };
202        *self
203            .bases
204            .borrow_mut()
205            .entry((smaller, larger))
206            .or_insert_with(|| self.merge_base_raw(smaller, larger))
207    }
208
209    fn merge_base_raw(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid> {
210        self.repo.merge_base(one, two).ok()
211    }
212
213    pub fn find_commit(&self, id: git2::Oid) -> Option<std::rc::Rc<Commit>> {
214        let mut commits = self.commits.borrow_mut();
215        if let Some(commit) = commits.get(&id) {
216            Some(std::rc::Rc::clone(commit))
217        } else {
218            let commit = self.repo.find_commit(id).ok()?;
219            let summary: bstr::BString = commit.summary_bytes().unwrap().into();
220            let time = std::time::SystemTime::UNIX_EPOCH
221                + std::time::Duration::from_secs(commit.time().seconds().max(0) as u64);
222
223            let author = commit.author().name().map(|n| self.intern_string(n));
224            let committer = commit.author().name().map(|n| self.intern_string(n));
225            let commit = std::rc::Rc::new(Commit {
226                id: commit.id(),
227                tree_id: commit.tree_id(),
228                summary,
229                time,
230                author,
231                committer,
232            });
233            commits.insert(id, std::rc::Rc::clone(&commit));
234            Some(commit)
235        }
236    }
237
238    pub fn head_commit(&self) -> std::rc::Rc<Commit> {
239        let head_id = self
240            .repo
241            .head()
242            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"))
243            .resolve()
244            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"))
245            .target()
246            .unwrap();
247        self.find_commit(head_id).unwrap()
248    }
249
250    pub fn head_branch(&self) -> Option<Branch> {
251        if self.repo.head_detached().unwrap_or(true) {
252            return None;
253        }
254
255        let resolved = self
256            .repo
257            .head()
258            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"))
259            .resolve()
260            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
261        let name = resolved.shorthand()?;
262        let id = resolved.target()?;
263
264        Some(Branch {
265            remote: None,
266            name: name.to_owned(),
267            id,
268        })
269    }
270
271    pub fn resolve(&self, revspec: &str) -> Option<std::rc::Rc<Commit>> {
272        let id = self.repo.revparse_single(revspec).ok()?.id();
273        self.find_commit(id)
274    }
275
276    pub fn parent_ids(&self, head_id: git2::Oid) -> Result<Vec<git2::Oid>> {
277        let commit = self.repo.find_commit(head_id)?;
278        Ok(commit.parent_ids().collect())
279    }
280
281    pub fn commit_count(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize> {
282        if base_id == head_id {
283            return Some(0);
284        }
285
286        *self
287            .counts
288            .borrow_mut()
289            .entry((base_id, head_id))
290            .or_insert_with(|| self.commit_count_raw(base_id, head_id))
291    }
292
293    fn commit_count_raw(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize> {
294        let merge_base_id = self.merge_base(base_id, head_id)?;
295        if merge_base_id != base_id {
296            return None;
297        }
298        let mut revwalk = self
299            .repo
300            .revwalk()
301            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
302        revwalk
303            .push(head_id)
304            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
305        revwalk
306            .hide(base_id)
307            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
308        Some(revwalk.count())
309    }
310
311    pub fn commit_range(
312        &self,
313        base_bound: std::ops::Bound<&git2::Oid>,
314        head_bound: std::ops::Bound<&git2::Oid>,
315    ) -> Result<Vec<git2::Oid>> {
316        let head_id = match head_bound {
317            std::ops::Bound::Included(head_id) | std::ops::Bound::Excluded(head_id) => *head_id,
318            std::ops::Bound::Unbounded => panic!("commit_range's HEAD cannot be unbounded"),
319        };
320        let skip = if matches!(head_bound, std::ops::Bound::Included(_)) {
321            0
322        } else {
323            1
324        };
325
326        let base_id = match base_bound {
327            std::ops::Bound::Included(base_id) | std::ops::Bound::Excluded(base_id) => {
328                debug_assert_eq!(self.merge_base(*base_id, head_id), Some(*base_id));
329                Some(*base_id)
330            }
331            std::ops::Bound::Unbounded => None,
332        };
333
334        let mut revwalk = self.repo.revwalk()?;
335        revwalk.push(head_id)?;
336        if let Some(base_id) = base_id {
337            revwalk.hide(base_id)?;
338        }
339        revwalk.set_sorting(git2::Sort::TOPOLOGICAL)?;
340        let mut result = revwalk
341            .filter_map(Result::ok)
342            .skip(skip)
343            .take_while(|id| Some(*id) != base_id)
344            .collect::<Vec<_>>();
345        if let std::ops::Bound::Included(base_id) = base_bound {
346            result.push(*base_id);
347        }
348        Ok(result)
349    }
350
351    pub fn contains_commit(&self, haystack_id: git2::Oid, needle_id: git2::Oid) -> Result<bool> {
352        let needle_commit = self.repo.find_commit(needle_id)?;
353        let needle_ann_commit = self.repo.find_annotated_commit(needle_id)?;
354        let haystack_ann_commit = self.repo.find_annotated_commit(haystack_id)?;
355
356        let parent_ann_commit = if 0 < needle_commit.parent_count() {
357            let parent_commit = needle_commit.parent(0)?;
358            Some(self.repo.find_annotated_commit(parent_commit.id())?)
359        } else {
360            None
361        };
362
363        let mut rebase = self.repo.rebase(
364            Some(&needle_ann_commit),
365            parent_ann_commit.as_ref(),
366            Some(&haystack_ann_commit),
367            Some(git2::RebaseOptions::new().inmemory(true)),
368        )?;
369
370        if let Some(op) = rebase.next() {
371            op.map_err(|e| {
372                let _ = rebase.abort();
373                e
374            })?;
375            let inmemory_index = rebase
376                .inmemory_index()
377                .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
378            if inmemory_index.has_conflicts() {
379                return Ok(false);
380            }
381
382            let sig = self
383                .repo
384                .signature()
385                .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
386            let result = rebase.commit(None, &sig, None).map_err(|e| {
387                let _ = rebase.abort();
388                e
389            });
390            match result {
391                // Created commit, must be unique
392                Ok(_) => Ok(false),
393                Err(err) => {
394                    if err.class() == git2::ErrorClass::Rebase
395                        && err.code() == git2::ErrorCode::Applied
396                    {
397                        return Ok(true);
398                    }
399                    Err(err)
400                }
401            }
402        } else {
403            // No commit created, must exist somehow
404            rebase.finish(None)?;
405            Ok(true)
406        }
407    }
408
409    pub fn cherry_pick(&mut self, head_id: git2::Oid, cherry_id: git2::Oid) -> Result<git2::Oid> {
410        git2_ext::ops::cherry_pick(
411            &self.repo,
412            head_id,
413            cherry_id,
414            self.sign.as_ref().map(|s| s as &dyn git2_ext::ops::Sign),
415        )
416    }
417
418    pub fn reword(&mut self, head_oid: git2::Oid, msg: &str) -> Result<git2::Oid> {
419        git2_ext::ops::reword(
420            &self.repo,
421            head_oid,
422            msg,
423            self.sign.as_ref().map(|s| s as &dyn git2_ext::ops::Sign),
424        )
425    }
426
427    pub fn squash(&mut self, head_id: git2::Oid, into_id: git2::Oid) -> Result<git2::Oid> {
428        git2_ext::ops::squash(
429            &self.repo,
430            head_id,
431            into_id,
432            self.sign.as_ref().map(|s| s as &dyn git2_ext::ops::Sign),
433        )
434    }
435
436    pub fn stash_push(&mut self, message: Option<&str>) -> Result<git2::Oid> {
437        let signature = self.repo.signature()?;
438        self.repo.stash_save2(&signature, message, None)
439    }
440
441    pub fn stash_pop(&mut self, stash_id: git2::Oid) -> Result<()> {
442        let mut index = None;
443        self.repo.stash_foreach(|i, _, id| {
444            if *id == stash_id {
445                index = Some(i);
446                false
447            } else {
448                true
449            }
450        })?;
451        let index = index.ok_or_else(|| {
452            Error::new(
453                git2::ErrorCode::NotFound,
454                git2::ErrorClass::Reference,
455                "stash ID not found",
456            )
457        })?;
458        self.repo.stash_pop(index, None)
459    }
460
461    pub fn branch(&mut self, name: &str, id: git2::Oid) -> Result<()> {
462        let commit = self.repo.find_commit(id)?;
463        self.repo.branch(name, &commit, true)?;
464        Ok(())
465    }
466
467    pub fn delete_branch(&mut self, name: &str) -> Result<()> {
468        // HACK: We shouldn't limit ourselves to `Local`
469        let mut branch = self.repo.find_branch(name, git2::BranchType::Local)?;
470        branch.delete()
471    }
472
473    pub fn find_local_branch(&self, name: &str) -> Option<Branch> {
474        let branch = self.repo.find_branch(name, git2::BranchType::Local).ok()?;
475        self.load_local_branch(&branch, name).ok()
476    }
477
478    pub fn find_remote_branch(&self, remote: &str, name: &str) -> Option<Branch> {
479        let qualified = format!("{remote}/{name}");
480        let branch = self
481            .repo
482            .find_branch(&qualified, git2::BranchType::Remote)
483            .ok()?;
484        self.load_remote_branch(&branch, remote, name).ok()
485    }
486
487    pub fn local_branches(&self) -> impl Iterator<Item = Branch> + '_ {
488        log::trace!("Loading local branches");
489        self.repo
490            .branches(Some(git2::BranchType::Local))
491            .into_iter()
492            .flatten()
493            .filter_map(move |branch| {
494                let (branch, _) = branch.ok()?;
495                let name = if let Some(name) = branch.name().ok().flatten() {
496                    name
497                } else {
498                    log::debug!(
499                        "Ignoring non-UTF8 branch {:?}",
500                        branch.name_bytes().unwrap().as_bstr()
501                    );
502                    return None;
503                };
504                self.load_local_branch(&branch, name).ok()
505            })
506    }
507
508    pub fn remote_branches(&self) -> impl Iterator<Item = Branch> + '_ {
509        log::trace!("Loading remote branches");
510        self.repo
511            .branches(Some(git2::BranchType::Remote))
512            .into_iter()
513            .flatten()
514            .filter_map(move |branch| {
515                let (branch, _) = branch.ok()?;
516                let name = if let Some(name) = branch.name().ok().flatten() {
517                    name
518                } else {
519                    log::debug!(
520                        "Ignoring non-UTF8 branch {:?}",
521                        branch.name_bytes().unwrap().as_bstr()
522                    );
523                    return None;
524                };
525                let (remote, name) = name.split_once('/').unwrap();
526                self.load_remote_branch(&branch, remote, name).ok()
527            })
528    }
529
530    fn load_local_branch(&self, branch: &git2::Branch<'_>, name: &str) -> Result<Branch> {
531        let id = branch.get().target().unwrap();
532
533        Ok(Branch {
534            remote: None,
535            name: name.to_owned(),
536            id,
537        })
538    }
539
540    fn load_remote_branch(
541        &self,
542        branch: &git2::Branch<'_>,
543        remote: &str,
544        name: &str,
545    ) -> Result<Branch> {
546        let id = branch.get().target().unwrap();
547
548        Ok(Branch {
549            remote: Some(remote.to_owned()),
550            name: name.to_owned(),
551            id,
552        })
553    }
554
555    pub fn detach(&mut self) -> Result<()> {
556        let head_id = self
557            .repo
558            .head()
559            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"))
560            .resolve()
561            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"))
562            .target()
563            .unwrap();
564        self.repo.set_head_detached(head_id)?;
565        Ok(())
566    }
567
568    pub fn switch_branch(&mut self, name: &str) -> Result<()> {
569        let head_tree_id = self.head_commit().tree_id;
570        let branch = self.repo.find_branch(name, git2::BranchType::Local)?;
571        let target_id = branch.get().target().unwrap();
572        let target_tree_id = self
573            .find_commit(target_id)
574            .map(|c| c.tree_id)
575            .unwrap_or_else(git2::Oid::zero);
576
577        self.repo.set_head(branch.get().name().unwrap())?;
578
579        if head_tree_id != target_tree_id {
580            let mut builder = git2::build::CheckoutBuilder::new();
581            builder.force();
582            self.repo.checkout_head(Some(&mut builder))?;
583        }
584
585        Ok(())
586    }
587
588    pub fn switch_commit(&mut self, id: git2::Oid) -> Result<()> {
589        let head_tree_id = self.head_commit().tree_id;
590        let target_tree_id = self
591            .find_commit(id)
592            .map(|c| c.tree_id)
593            .unwrap_or_else(git2::Oid::zero);
594
595        self.repo.set_head_detached(id)?;
596
597        if head_tree_id != target_tree_id {
598            let mut builder = git2::build::CheckoutBuilder::new();
599            builder.force();
600            self.repo.checkout_head(Some(&mut builder))?;
601        }
602        Ok(())
603    }
604
605    fn intern_string(&self, data: &str) -> std::rc::Rc<str> {
606        let mut interned_strings = self.interned_strings.borrow_mut();
607        if let Some(interned) = interned_strings.get(data) {
608            std::rc::Rc::clone(interned)
609        } else {
610            let interned = std::rc::Rc::from(data);
611            interned_strings.insert(std::rc::Rc::clone(&interned));
612            interned
613        }
614    }
615}
616
617impl std::fmt::Debug for GitRepo {
618    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
619        f.debug_struct("GitRepo")
620            .field("repo", &self.repo.workdir())
621            .field("push_remote", &self.push_remote.as_deref())
622            .field("pull_remote", &self.pull_remote.as_deref())
623            .finish()
624    }
625}
626
627impl Repo for GitRepo {
628    fn path(&self) -> Option<&std::path::Path> {
629        Some(self.repo.path())
630    }
631    fn user(&self) -> Option<std::rc::Rc<str>> {
632        self.user()
633    }
634    fn push_remote(&self) -> &str {
635        self.push_remote()
636    }
637    fn pull_remote(&self) -> &str {
638        self.pull_remote()
639    }
640
641    fn is_dirty(&self) -> bool {
642        self.is_dirty()
643    }
644
645    fn merge_base(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid> {
646        self.merge_base(one, two)
647    }
648
649    fn find_commit(&self, id: git2::Oid) -> Option<std::rc::Rc<Commit>> {
650        self.find_commit(id)
651    }
652
653    fn head_commit(&self) -> std::rc::Rc<Commit> {
654        self.head_commit()
655    }
656
657    fn head_branch(&self) -> Option<Branch> {
658        self.head_branch()
659    }
660
661    fn resolve(&self, revspec: &str) -> Option<std::rc::Rc<Commit>> {
662        self.resolve(revspec)
663    }
664
665    fn parent_ids(&self, head_id: git2::Oid) -> Result<Vec<git2::Oid>> {
666        self.parent_ids(head_id)
667    }
668
669    fn commit_count(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize> {
670        self.commit_count(base_id, head_id)
671    }
672
673    fn commit_range(
674        &self,
675        base_bound: std::ops::Bound<&git2::Oid>,
676        head_bound: std::ops::Bound<&git2::Oid>,
677    ) -> Result<Vec<git2::Oid>> {
678        self.commit_range(base_bound, head_bound)
679    }
680
681    fn contains_commit(&self, haystack_id: git2::Oid, needle_id: git2::Oid) -> Result<bool> {
682        self.contains_commit(haystack_id, needle_id)
683    }
684
685    fn cherry_pick(&mut self, head_id: git2::Oid, cherry_id: git2::Oid) -> Result<git2::Oid> {
686        self.cherry_pick(head_id, cherry_id)
687    }
688
689    fn reword(&mut self, head_oid: git2::Oid, msg: &str) -> Result<git2::Oid> {
690        self.reword(head_oid, msg)
691    }
692
693    fn squash(&mut self, head_id: git2::Oid, into_id: git2::Oid) -> Result<git2::Oid> {
694        self.squash(head_id, into_id)
695    }
696
697    fn stash_push(&mut self, message: Option<&str>) -> Result<git2::Oid> {
698        self.stash_push(message)
699    }
700
701    fn stash_pop(&mut self, stash_id: git2::Oid) -> Result<()> {
702        self.stash_pop(stash_id)
703    }
704
705    fn branch(&mut self, name: &str, id: git2::Oid) -> Result<()> {
706        self.branch(name, id)
707    }
708
709    fn delete_branch(&mut self, name: &str) -> Result<()> {
710        self.delete_branch(name)
711    }
712
713    fn find_local_branch(&self, name: &str) -> Option<Branch> {
714        self.find_local_branch(name)
715    }
716
717    fn find_remote_branch(&self, remote: &str, name: &str) -> Option<Branch> {
718        self.find_remote_branch(remote, name)
719    }
720
721    fn local_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_> {
722        Box::new(self.local_branches())
723    }
724
725    fn remote_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_> {
726        Box::new(self.remote_branches())
727    }
728
729    fn detach(&mut self) -> Result<()> {
730        self.detach()
731    }
732
733    fn switch_branch(&mut self, name: &str) -> Result<()> {
734        self.switch_branch(name)
735    }
736
737    fn switch_commit(&mut self, id: git2::Oid) -> Result<()> {
738        self.switch_commit(id)
739    }
740}
741
/// In-memory model of a repository's commits, branches and `HEAD`.
#[derive(Debug)]
pub struct InMemoryRepo {
    /// Commits keyed by id, each stored with its single optional parent id.
    commits: std::collections::HashMap<git2::Oid, (Option<git2::Oid>, std::rc::Rc<Commit>)>,
    /// Branches keyed by branch name.
    branches: std::collections::HashMap<String, Branch>,
    /// Current `HEAD` commit, if one has been set.
    head_id: Option<git2::Oid>,

    /// Counter used by `gen_id` to mint sequential commit ids.
    last_id: std::sync::atomic::AtomicUsize,
}
750
751impl InMemoryRepo {
752    pub fn new() -> Self {
753        Self {
754            commits: Default::default(),
755            branches: Default::default(),
756            head_id: Default::default(),
757            last_id: std::sync::atomic::AtomicUsize::new(1),
758        }
759    }
760
761    pub fn clear(&mut self) {
762        *self = InMemoryRepo::new();
763    }
764
765    pub fn gen_id(&mut self) -> git2::Oid {
766        let last_id = self
767            .last_id
768            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
769        let sha = format!("{last_id:040x}");
770        git2::Oid::from_str(&sha).unwrap()
771    }
772
773    pub fn push_commit(&mut self, parent_id: Option<git2::Oid>, commit: Commit) {
774        if let Some(parent_id) = parent_id {
775            assert!(self.commits.contains_key(&parent_id));
776        }
777        self.head_id = Some(commit.id);
778        self.commits
779            .insert(commit.id, (parent_id, std::rc::Rc::new(commit)));
780    }
781
782    pub fn head_id(&mut self) -> Option<git2::Oid> {
783        self.head_id
784    }
785
786    pub fn set_head(&mut self, head_id: git2::Oid) {
787        assert!(self.commits.contains_key(&head_id));
788        self.head_id = Some(head_id);
789    }
790
791    pub fn mark_branch(&mut self, branch: Branch) {
792        assert!(self.commits.contains_key(&branch.id));
793        self.branches.insert(branch.name.clone(), branch);
794    }
795
796    pub fn push_remote(&self) -> &str {
797        "origin"
798    }
799
800    pub fn pull_remote(&self) -> &str {
801        "origin"
802    }
803
804    fn user(&self) -> Option<std::rc::Rc<str>> {
805        None
806    }
807
808    pub fn is_dirty(&self) -> bool {
809        false
810    }
811
812    pub fn merge_base(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid> {
813        let one_ancestors: Vec<_> = self.commits_from(one).collect();
814        self.commits_from(two)
815            .filter(|two_ancestor| one_ancestors.contains(two_ancestor))
816            .map(|c| c.id)
817            .next()
818    }
819
820    pub fn find_commit(&self, id: git2::Oid) -> Option<std::rc::Rc<Commit>> {
821        self.commits.get(&id).map(|c| c.1.clone())
822    }
823
824    pub fn head_commit(&self) -> std::rc::Rc<Commit> {
825        self.commits.get(&self.head_id.unwrap()).cloned().unwrap().1
826    }
827
828    pub fn head_branch(&self) -> Option<Branch> {
829        self.branches
830            .values()
831            .find(|b| b.id == self.head_id.unwrap())
832            .cloned()
833    }
834
835    pub fn resolve(&self, revspec: &str) -> Option<std::rc::Rc<Commit>> {
836        let branch = self.branches.get(revspec)?;
837        self.find_commit(branch.id)
838    }
839
840    pub fn parent_ids(&self, head_id: git2::Oid) -> Result<Vec<git2::Oid>> {
841        let next = self
842            .commits
843            .get(&head_id)
844            .and_then(|(parent, _commit)| *parent);
845        Ok(next.into_iter().collect())
846    }
847
848    fn commits_from(&self, head_id: git2::Oid) -> impl Iterator<Item = std::rc::Rc<Commit>> + '_ {
849        let next = self.commits.get(&head_id).cloned();
850        CommitsFrom {
851            commits: &self.commits,
852            next,
853        }
854    }
855
856    pub fn commit_count(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize> {
857        let merge_base_id = self.merge_base(base_id, head_id)?;
858        let count = self
859            .commits_from(head_id)
860            .take_while(move |cur_id| cur_id.id != merge_base_id)
861            .count();
862        Some(count)
863    }
864
865    pub fn commit_range(
866        &self,
867        base_bound: std::ops::Bound<&git2::Oid>,
868        head_bound: std::ops::Bound<&git2::Oid>,
869    ) -> Result<Vec<git2::Oid>> {
870        let head_id = match head_bound {
871            std::ops::Bound::Included(head_id) | std::ops::Bound::Excluded(head_id) => *head_id,
872            std::ops::Bound::Unbounded => panic!("commit_range's HEAD cannot be unbounded"),
873        };
874        let skip = if matches!(head_bound, std::ops::Bound::Included(_)) {
875            0
876        } else {
877            1
878        };
879
880        let base_id = match base_bound {
881            std::ops::Bound::Included(base_id) | std::ops::Bound::Excluded(base_id) => {
882                debug_assert_eq!(self.merge_base(*base_id, head_id), Some(*base_id));
883                Some(*base_id)
884            }
885            std::ops::Bound::Unbounded => None,
886        };
887
888        let mut result = self
889            .commits_from(head_id)
890            .skip(skip)
891            .map(|commit| commit.id)
892            .take_while(|id| Some(*id) != base_id)
893            .collect::<Vec<_>>();
894        if let std::ops::Bound::Included(base_id) = base_bound {
895            result.push(*base_id);
896        }
897        Ok(result)
898    }
899
900    pub fn contains_commit(&self, haystack_id: git2::Oid, needle_id: git2::Oid) -> Result<bool> {
901        // Because we don't have the information for likeness matches, just checking for Oid
902        let mut next = Some(haystack_id);
903        while let Some(current) = next {
904            if current == needle_id {
905                return Ok(true);
906            }
907            next = self.commits.get(&current).and_then(|c| c.0);
908        }
909        Ok(false)
910    }
911
912    pub fn cherry_pick(&mut self, head_id: git2::Oid, cherry_id: git2::Oid) -> Result<git2::Oid> {
913        let cherry_commit = self.find_commit(cherry_id).ok_or_else(|| {
914            Error::new(
915                git2::ErrorCode::NotFound,
916                git2::ErrorClass::Reference,
917                format!("could not find commit {cherry_id:?}"),
918            )
919        })?;
920        let mut cherry_commit = Commit::clone(&cherry_commit);
921        let new_id = self.gen_id();
922        cherry_commit.id = new_id;
923        self.commits
924            .insert(new_id, (Some(head_id), std::rc::Rc::new(cherry_commit)));
925        Ok(new_id)
926    }
927
928    pub fn reword(&mut self, head_id: git2::Oid, msg: &str) -> Result<git2::Oid> {
929        let (head_parent, head_commit) = self.commits.get(&head_id).cloned().ok_or_else(|| {
930            git2::Error::new(
931                git2::ErrorCode::NotFound,
932                git2::ErrorClass::Reference,
933                format!("could not find commit {head_id:?}"),
934            )
935        })?;
936
937        let mut reworded_commit = Commit::clone(&head_commit);
938        let new_id = self.gen_id();
939        reworded_commit.id = new_id;
940        reworded_commit.summary = msg.into();
941        self.commits
942            .insert(new_id, (head_parent, std::rc::Rc::new(reworded_commit)));
943        Ok(new_id)
944    }
945
946    pub fn squash(&mut self, head_id: git2::Oid, into_id: git2::Oid) -> Result<git2::Oid> {
947        self.commits.get(&head_id).cloned().ok_or_else(|| {
948            Error::new(
949                git2::ErrorCode::NotFound,
950                git2::ErrorClass::Reference,
951                format!("could not find commit {head_id:?}"),
952            )
953        })?;
954        let (intos_parent, into_commit) = self.commits.get(&into_id).cloned().ok_or_else(|| {
955            Error::new(
956                git2::ErrorCode::NotFound,
957                git2::ErrorClass::Reference,
958                format!("could not find commit {into_id:?}"),
959            )
960        })?;
961
962        let mut squashed_commit = Commit::clone(&into_commit);
963        let new_id = self.gen_id();
964        squashed_commit.id = new_id;
965        self.commits
966            .insert(new_id, (intos_parent, std::rc::Rc::new(squashed_commit)));
967        Ok(new_id)
968    }
969
970    pub fn stash_push(&mut self, _message: Option<&str>) -> Result<git2::Oid> {
971        Err(Error::new(
972            git2::ErrorCode::NotFound,
973            git2::ErrorClass::Reference,
974            "stash is unsupported",
975        ))
976    }
977
978    pub fn stash_pop(&mut self, _stash_id: git2::Oid) -> Result<()> {
979        Err(Error::new(
980            git2::ErrorCode::NotFound,
981            git2::ErrorClass::Reference,
982            "stash is unsupported",
983        ))
984    }
985
986    pub fn branch(&mut self, name: &str, id: git2::Oid) -> Result<()> {
987        self.branches.insert(
988            name.to_owned(),
989            Branch {
990                remote: None,
991                name: name.to_owned(),
992                id,
993            },
994        );
995        Ok(())
996    }
997
998    pub fn delete_branch(&mut self, name: &str) -> Result<()> {
999        self.branches.remove(name).map(|_| ()).ok_or_else(|| {
1000            Error::new(
1001                git2::ErrorCode::NotFound,
1002                git2::ErrorClass::Reference,
1003                format!("could not remove branch {name:?}"),
1004            )
1005        })
1006    }
1007
1008    pub fn find_local_branch(&self, name: &str) -> Option<Branch> {
1009        self.branches.get(name).cloned()
1010    }
1011
1012    pub fn find_remote_branch(&self, _remote: &str, _name: &str) -> Option<Branch> {
1013        None
1014    }
1015
1016    pub fn local_branches(&self) -> impl Iterator<Item = Branch> + '_ {
1017        self.branches.values().cloned()
1018    }
1019
1020    pub fn remote_branches(&self) -> impl Iterator<Item = Branch> + '_ {
1021        None.into_iter()
1022    }
1023
1024    pub fn detach(&mut self) -> Result<()> {
1025        Ok(())
1026    }
1027
1028    pub fn switch_branch(&mut self, name: &str) -> Result<()> {
1029        let branch = self.find_local_branch(name).ok_or_else(|| {
1030            Error::new(
1031                git2::ErrorCode::NotFound,
1032                git2::ErrorClass::Reference,
1033                format!("could not find branch {name:?}"),
1034            )
1035        })?;
1036        self.head_id = Some(branch.id);
1037        Ok(())
1038    }
1039
1040    pub fn switch_commit(&mut self, id: git2::Oid) -> Result<()> {
1041        self.head_id = Some(id);
1042        Ok(())
1043    }
1044}
1045
1046impl Default for InMemoryRepo {
1047    fn default() -> Self {
1048        Self::new()
1049    }
1050}
1051
/// Iterator that yields a commit and then each of its recorded ancestors, following the
/// single parent id stored alongside every commit in the map.
struct CommitsFrom<'c> {
    /// Commit store keyed by id; each value is `(parent id, commit)`. Its shape matches the
    /// in-memory repo's `commits` map.
    commits: &'c std::collections::HashMap<git2::Oid, (Option<git2::Oid>, std::rc::Rc<Commit>)>,
    /// Entry to yield on the next call; `None` once the walk has ended.
    next: Option<(Option<git2::Oid>, std::rc::Rc<Commit>)>,
}
1056
1057impl Iterator for CommitsFrom<'_> {
1058    type Item = std::rc::Rc<Commit>;
1059
1060    fn next(&mut self) -> Option<Self::Item> {
1061        let mut current = None;
1062        std::mem::swap(&mut current, &mut self.next);
1063        let current = current?;
1064        if let Some(parent_id) = current.0 {
1065            self.next = self.commits.get(&parent_id).cloned();
1066        }
1067        Some(current.1)
1068    }
1069}
1070
1071impl Repo for InMemoryRepo {
1072    fn path(&self) -> Option<&std::path::Path> {
1073        None
1074    }
1075    fn user(&self) -> Option<std::rc::Rc<str>> {
1076        self.user()
1077    }
1078    fn push_remote(&self) -> &str {
1079        self.push_remote()
1080    }
1081    fn pull_remote(&self) -> &str {
1082        self.pull_remote()
1083    }
1084
1085    fn is_dirty(&self) -> bool {
1086        self.is_dirty()
1087    }
1088
1089    fn merge_base(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid> {
1090        self.merge_base(one, two)
1091    }
1092
1093    fn find_commit(&self, id: git2::Oid) -> Option<std::rc::Rc<Commit>> {
1094        self.find_commit(id)
1095    }
1096
1097    fn head_commit(&self) -> std::rc::Rc<Commit> {
1098        self.head_commit()
1099    }
1100
1101    fn resolve(&self, revspec: &str) -> Option<std::rc::Rc<Commit>> {
1102        self.resolve(revspec)
1103    }
1104
1105    fn parent_ids(&self, head_id: git2::Oid) -> Result<Vec<git2::Oid>> {
1106        self.parent_ids(head_id)
1107    }
1108
1109    fn commit_count(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize> {
1110        self.commit_count(base_id, head_id)
1111    }
1112
1113    fn commit_range(
1114        &self,
1115        base_bound: std::ops::Bound<&git2::Oid>,
1116        head_bound: std::ops::Bound<&git2::Oid>,
1117    ) -> Result<Vec<git2::Oid>> {
1118        self.commit_range(base_bound, head_bound)
1119    }
1120
1121    fn contains_commit(&self, haystack_id: git2::Oid, needle_id: git2::Oid) -> Result<bool> {
1122        self.contains_commit(haystack_id, needle_id)
1123    }
1124
1125    fn cherry_pick(&mut self, head_id: git2::Oid, cherry_id: git2::Oid) -> Result<git2::Oid> {
1126        self.cherry_pick(head_id, cherry_id)
1127    }
1128
1129    fn reword(&mut self, head_oid: git2::Oid, msg: &str) -> Result<git2::Oid> {
1130        self.reword(head_oid, msg)
1131    }
1132
1133    fn squash(&mut self, head_id: git2::Oid, into_id: git2::Oid) -> Result<git2::Oid> {
1134        self.squash(head_id, into_id)
1135    }
1136
1137    fn head_branch(&self) -> Option<Branch> {
1138        self.head_branch()
1139    }
1140
1141    fn stash_push(&mut self, message: Option<&str>) -> Result<git2::Oid> {
1142        self.stash_push(message)
1143    }
1144
1145    fn stash_pop(&mut self, stash_id: git2::Oid) -> Result<()> {
1146        self.stash_pop(stash_id)
1147    }
1148
1149    fn branch(&mut self, name: &str, id: git2::Oid) -> Result<()> {
1150        self.branch(name, id)
1151    }
1152
1153    fn delete_branch(&mut self, name: &str) -> Result<()> {
1154        self.delete_branch(name)
1155    }
1156
1157    fn find_local_branch(&self, name: &str) -> Option<Branch> {
1158        self.find_local_branch(name)
1159    }
1160
1161    fn find_remote_branch(&self, remote: &str, name: &str) -> Option<Branch> {
1162        self.find_remote_branch(remote, name)
1163    }
1164
1165    fn local_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_> {
1166        Box::new(self.local_branches())
1167    }
1168
1169    fn remote_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_> {
1170        Box::new(self.remote_branches())
1171    }
1172
1173    fn detach(&mut self) -> Result<()> {
1174        self.detach()
1175    }
1176
1177    fn switch_branch(&mut self, name: &str) -> Result<()> {
1178        self.switch_branch(name)
1179    }
1180
1181    fn switch_commit(&mut self, id: git2::Oid) -> Result<()> {
1182        self.switch_commit(id)
1183    }
1184}
1185
1186pub fn stash_push(repo: &mut dyn Repo, context: &str) -> Option<git2::Oid> {
1187    let branch = repo.head_branch();
1188    if !repo.is_dirty() {
1189        log::debug!("Nothing to stash");
1190        return None;
1191    }
1192
1193    let stash_msg = format!(
1194        "WIP on {} ({})",
1195        branch.as_ref().map(|b| b.name.as_str()).unwrap_or("HEAD"),
1196        context
1197    );
1198    match repo.stash_push(Some(&stash_msg)) {
1199        Ok(stash_id) => {
1200            log::info!(
1201                "Saved working directory and index state {}: {}",
1202                stash_msg,
1203                stash_id
1204            );
1205            Some(stash_id)
1206        }
1207        Err(err) => {
1208            log::debug!("Failed to stash: {}", err);
1209            None
1210        }
1211    }
1212}
1213
1214pub fn stash_pop(repo: &mut dyn Repo, stash_id: Option<git2::Oid>) {
1215    if let Some(stash_id) = stash_id {
1216        match repo.stash_pop(stash_id) {
1217            Ok(()) => {
1218                log::info!("Dropped refs/stash {}", stash_id);
1219            }
1220            Err(err) => {
1221                log::error!("Failed to pop {} from stash: {}", stash_id, err);
1222            }
1223        }
1224    }
1225}
1226
1227pub fn commit_range(
1228    repo: &dyn Repo,
1229    head_to_base: impl std::ops::RangeBounds<git2::Oid>,
1230) -> Result<Vec<git2::Oid>> {
1231    let head_bound = head_to_base.start_bound();
1232    let base_bound = head_to_base.end_bound();
1233    repo.commit_range(base_bound, head_bound)
1234}