git_stack/git/
repo.rs

1use bstr::ByteSlice;
2use itertools::Itertools;
3
4pub type Error = git2::Error;
5pub type Result<T, E = Error> = std::result::Result<T, E>;
6
/// Repository operations used by the rest of the crate.
///
/// This file provides two implementors: [`GitRepo`], backed by `git2`, and
/// [`InMemoryRepo`], which keeps its commit graph purely in memory.
pub trait Repo {
    /// Location of the repository on disk, if it has one.
    fn path(&self) -> Option<&std::path::Path>;
    /// Name of the configured user, if known.
    fn user(&self) -> Option<std::rc::Rc<str>>;
    /// Name of the remote that branches are pushed to.
    fn push_remote(&self) -> &str;
    /// Name of the remote that branches are pulled from.
    fn pull_remote(&self) -> &str;

    /// Whether the working tree has changes.
    fn is_dirty(&self) -> bool;
    /// Common ancestor of `one` and `two`, or `None` if none can be found.
    fn merge_base(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid>;

    /// Looks up a commit by id.
    fn find_commit(&self, id: git2::Oid) -> Option<std::rc::Rc<Commit>>;
    /// The commit `HEAD` currently points at.
    fn head_commit(&self) -> std::rc::Rc<Commit>;
    /// The branch `HEAD` is on, if any.
    fn head_branch(&self) -> Option<Branch>;
    /// Resolves a revision specification to a commit.
    fn resolve(&self, revspec: &str) -> Option<std::rc::Rc<Commit>>;
    /// Ids of the parents of `head_id`.
    fn parent_ids(&self, head_id: git2::Oid) -> Result<Vec<git2::Oid>>;
    /// Number of commits between `base_id` and `head_id`, or `None` if it cannot be
    /// determined (`GitRepo` returns `None` when `base_id` is not an ancestor of `head_id`).
    fn commit_count(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize>;
    /// Commit ids walked from `head_bound` back towards `base_bound`, head first.
    ///
    /// `head_bound` must not be `Unbounded`; implementations in this file panic on it.
    fn commit_range(
        &self,
        base_bound: std::ops::Bound<&git2::Oid>,
        head_bound: std::ops::Bound<&git2::Oid>,
    ) -> Result<Vec<git2::Oid>>;
    /// Whether the history at `haystack_id` already contains `needle_id`.
    ///
    /// Implementations may match equivalent changes rather than only identical ids.
    fn contains_commit(&self, haystack_id: git2::Oid, needle_id: git2::Oid) -> Result<bool>;
    /// Applies `cherry_id` on top of `head_id`, returning the id of the new commit.
    fn cherry_pick(&mut self, head_id: git2::Oid, cherry_id: git2::Oid) -> Result<git2::Oid>;
    /// Rewrites the message of `head_oid` to `msg`, returning the id of the new commit.
    fn reword(&mut self, head_oid: git2::Oid, msg: &str) -> Result<git2::Oid>;
    /// Squashes `head_id` into `into_id`, returning the id of the new commit.
    fn squash(&mut self, head_id: git2::Oid, into_id: git2::Oid) -> Result<git2::Oid>;

    /// Stashes pending changes, returning the id of the stash entry.
    fn stash_push(&mut self, message: Option<&str>) -> Result<git2::Oid>;
    /// Restores and drops the stash entry identified by `stash_id`.
    fn stash_pop(&mut self, stash_id: git2::Oid) -> Result<()>;

    /// Points branch `name` at commit `id`.
    fn branch(&mut self, name: &str, id: git2::Oid) -> Result<()>;
    /// Deletes branch `name`.
    fn delete_branch(&mut self, name: &str) -> Result<()>;
    /// Looks up a local branch by name.
    fn find_local_branch(&self, name: &str) -> Option<Branch>;
    /// Looks up branch `name` on `remote`.
    fn find_remote_branch(&self, remote: &str, name: &str) -> Option<Branch>;
    /// All local branches.
    fn local_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_>;
    /// All remote-tracking branches.
    fn remote_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_>;
    /// Detaches `HEAD` at its current commit.
    fn detach(&mut self) -> Result<()>;
    /// Checks out local branch `name`.
    fn switch_branch(&mut self, name: &str) -> Result<()>;
    /// Checks out commit `id` with a detached `HEAD`.
    fn switch_commit(&mut self, id: git2::Oid) -> Result<()>;
}
45
/// A branch and the commit it points at.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Branch {
    /// Remote the branch belongs to; `None` for a local branch.
    pub remote: Option<String>,
    /// Branch name, without the `remote/` prefix for remote branches.
    pub name: String,
    /// Commit the branch points at.
    pub id: git2::Oid,
}
52
53impl Branch {
54    pub fn local_name(&self) -> Option<&str> {
55        self.remote.is_none().then_some(self.name.as_str())
56    }
57}
58
59impl std::fmt::Display for Branch {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        if let Some(remote) = self.remote.as_deref() {
62            write!(f, "{}/{}", remote, self.name.as_str())
63        } else {
64            write!(f, "{}", self.name.as_str())
65        }
66    }
67}
68
/// Metadata for a single commit, as cached by [`GitRepo`] or stored by [`InMemoryRepo`].
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Commit {
    /// Id of this commit.
    pub id: git2::Oid,
    /// Id of the tree this commit records.
    pub tree_id: git2::Oid,
    /// Summary line of the commit message.
    pub summary: bstr::BString,
    /// Commit time. `GitRepo::find_commit` stores whole seconds and clamps pre-epoch
    /// timestamps to the Unix epoch.
    pub time: std::time::SystemTime,
    /// Author name, if recorded (interned by `GitRepo`).
    pub author: Option<std::rc::Rc<str>>,
    /// Committer name, if recorded (interned by `GitRepo`).
    pub committer: Option<std::rc::Rc<str>>,
}
78
79impl Commit {
80    pub fn fixup_summary(&self) -> Option<&bstr::BStr> {
81        self.summary
82            .strip_prefix(b"fixup! ")
83            .map(ByteSlice::as_bstr)
84    }
85
86    pub fn wip_summary(&self) -> Option<&bstr::BStr> {
87        // Gitlab MRs only: b"[Draft]", b"(Draft)",
88        static WIP_PREFIXES: &[&[u8]] = &[
89            b"WIP:", b"draft:", b"Draft:", // Gitlab commits
90            b"wip ", b"WIP ", // Less formal
91        ];
92
93        if self.summary == b"WIP".as_bstr() || self.summary == b"wip".as_bstr() {
94            // Very informal
95            Some(b"".as_bstr())
96        } else {
97            WIP_PREFIXES.iter().find_map(|prefix| {
98                self.summary
99                    .strip_prefix(*prefix)
100                    .map(ByteSlice::trim)
101                    .map(ByteSlice::as_bstr)
102            })
103        }
104    }
105
106    pub fn revert_summary(&self) -> Option<&bstr::BStr> {
107        self.summary
108            .strip_prefix(b"Revert ")
109            .and_then(|s| s.strip_suffix(b"\""))
110            .map(ByteSlice::as_bstr)
111    }
112}
113
/// Repository backed by a real `git2::Repository`.
///
/// Lookups are memoized in `RefCell`-guarded caches so they can be served through `&self`.
/// Because those caches hold `Rc`/`RefCell`, a `GitRepo` is not thread-safe.
pub struct GitRepo {
    repo: git2::Repository,
    /// Signer for commits created by cherry-pick/reword/squash; `None` disables signing.
    sign: Option<git2_ext::ops::UserSign>,
    /// Push remote override; `"origin"` is used when `None`.
    push_remote: Option<String>,
    /// Pull remote override; `"origin"` is used when `None`.
    pull_remote: Option<String>,
    /// Cache of loaded commit metadata, keyed by commit id.
    commits: std::cell::RefCell<std::collections::HashMap<git2::Oid, std::rc::Rc<Commit>>>,
    /// Interned user/author/committer names, shared between cached commits.
    interned_strings: std::cell::RefCell<std::collections::HashSet<std::rc::Rc<str>>>,
    /// Memoized merge bases, keyed by the (smaller, larger) id pair.
    bases: std::cell::RefCell<std::collections::HashMap<(git2::Oid, git2::Oid), Option<git2::Oid>>>,
    /// Memoized `commit_count` results, keyed by (base, head).
    counts: std::cell::RefCell<std::collections::HashMap<(git2::Oid, git2::Oid), Option<usize>>>,
}
124
125impl GitRepo {
126    pub fn new(repo: git2::Repository) -> Self {
127        Self {
128            repo,
129            sign: None,
130            push_remote: None,
131            pull_remote: None,
132            commits: Default::default(),
133            interned_strings: Default::default(),
134            bases: Default::default(),
135            counts: Default::default(),
136        }
137    }
138
139    pub fn set_sign(&mut self, yes: bool) -> Result<(), git2::Error> {
140        if yes {
141            let config = self.repo.config()?;
142            let sign = git2_ext::ops::UserSign::from_config(&self.repo, &config)?;
143            self.sign = Some(sign);
144        } else {
145            self.sign = None;
146        }
147        Ok(())
148    }
149
150    pub fn set_push_remote(&mut self, remote: &str) {
151        self.push_remote = Some(remote.to_owned());
152    }
153
154    pub fn set_pull_remote(&mut self, remote: &str) {
155        self.pull_remote = Some(remote.to_owned());
156    }
157
158    pub fn push_remote(&self) -> &str {
159        self.push_remote.as_deref().unwrap_or("origin")
160    }
161
162    pub fn pull_remote(&self) -> &str {
163        self.pull_remote.as_deref().unwrap_or("origin")
164    }
165
166    pub fn raw(&self) -> &git2::Repository {
167        &self.repo
168    }
169
170    pub fn user(&self) -> Option<std::rc::Rc<str>> {
171        self.repo
172            .signature()
173            .ok()
174            .and_then(|s| s.name().map(|n| self.intern_string(n)))
175    }
176
177    pub fn is_dirty(&self) -> bool {
178        let status = self
179            .repo
180            .statuses(Some(git2::StatusOptions::new().include_ignored(false)))
181            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
182        if status.is_empty() {
183            false
184        } else {
185            log::trace!(
186                "Repository is dirty: {}",
187                status
188                    .iter()
189                    .filter_map(|s| s.path().map(|s| s.to_owned()))
190                    .join(", ")
191            );
192            true
193        }
194    }
195
196    pub fn merge_base(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid> {
197        if one == two {
198            return Some(one);
199        }
200
201        let (smaller, larger) = if one < two { (one, two) } else { (two, one) };
202        *self
203            .bases
204            .borrow_mut()
205            .entry((smaller, larger))
206            .or_insert_with(|| self.merge_base_raw(smaller, larger))
207    }
208
209    fn merge_base_raw(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid> {
210        self.repo.merge_base(one, two).ok()
211    }
212
213    pub fn find_commit(&self, id: git2::Oid) -> Option<std::rc::Rc<Commit>> {
214        let mut commits = self.commits.borrow_mut();
215        if let Some(commit) = commits.get(&id) {
216            Some(std::rc::Rc::clone(commit))
217        } else {
218            let commit = self.repo.find_commit(id).ok()?;
219            let summary: bstr::BString = commit.summary_bytes().unwrap().into();
220            let time = std::time::SystemTime::UNIX_EPOCH
221                + std::time::Duration::from_secs(commit.time().seconds().max(0) as u64);
222
223            let author = commit.author().name().map(|n| self.intern_string(n));
224            let committer = commit.author().name().map(|n| self.intern_string(n));
225            let commit = std::rc::Rc::new(Commit {
226                id: commit.id(),
227                tree_id: commit.tree_id(),
228                summary,
229                time,
230                author,
231                committer,
232            });
233            commits.insert(id, std::rc::Rc::clone(&commit));
234            Some(commit)
235        }
236    }
237
238    pub fn head_commit(&self) -> std::rc::Rc<Commit> {
239        let head_id = self
240            .repo
241            .head()
242            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"))
243            .resolve()
244            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"))
245            .target()
246            .unwrap();
247        self.find_commit(head_id).unwrap()
248    }
249
250    pub fn head_branch(&self) -> Option<Branch> {
251        if self.repo.head_detached().unwrap_or(true) {
252            return None;
253        }
254
255        let resolved = self
256            .repo
257            .head()
258            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"))
259            .resolve()
260            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
261        let name = resolved.shorthand()?;
262        let id = resolved.target()?;
263
264        Some(Branch {
265            remote: None,
266            name: name.to_owned(),
267            id,
268        })
269    }
270
271    pub fn resolve(&self, revspec: &str) -> Option<std::rc::Rc<Commit>> {
272        let id = self.repo.revparse_single(revspec).ok()?.id();
273        self.find_commit(id)
274    }
275
276    pub fn parent_ids(&self, head_id: git2::Oid) -> Result<Vec<git2::Oid>> {
277        let commit = self.repo.find_commit(head_id)?;
278        Ok(commit.parent_ids().collect())
279    }
280
281    pub fn commit_count(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize> {
282        if base_id == head_id {
283            return Some(0);
284        }
285
286        *self
287            .counts
288            .borrow_mut()
289            .entry((base_id, head_id))
290            .or_insert_with(|| self.commit_count_raw(base_id, head_id))
291    }
292
293    fn commit_count_raw(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize> {
294        let merge_base_id = self.merge_base(base_id, head_id)?;
295        if merge_base_id != base_id {
296            return None;
297        }
298        let mut revwalk = self
299            .repo
300            .revwalk()
301            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
302        revwalk
303            .push(head_id)
304            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
305        revwalk
306            .hide(base_id)
307            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
308        Some(revwalk.count())
309    }
310
311    pub fn commit_range(
312        &self,
313        base_bound: std::ops::Bound<&git2::Oid>,
314        head_bound: std::ops::Bound<&git2::Oid>,
315    ) -> Result<Vec<git2::Oid>> {
316        let head_id = match head_bound {
317            std::ops::Bound::Included(head_id) | std::ops::Bound::Excluded(head_id) => *head_id,
318            std::ops::Bound::Unbounded => panic!("commit_range's HEAD cannot be unbounded"),
319        };
320        let skip = if matches!(head_bound, std::ops::Bound::Included(_)) {
321            0
322        } else {
323            1
324        };
325
326        let base_id = match base_bound {
327            std::ops::Bound::Included(base_id) | std::ops::Bound::Excluded(base_id) => {
328                debug_assert_eq!(self.merge_base(*base_id, head_id), Some(*base_id));
329                Some(*base_id)
330            }
331            std::ops::Bound::Unbounded => None,
332        };
333
334        let mut revwalk = self.repo.revwalk()?;
335        revwalk.push(head_id)?;
336        if let Some(base_id) = base_id {
337            revwalk.hide(base_id)?;
338        }
339        revwalk.set_sorting(git2::Sort::TOPOLOGICAL)?;
340        let mut result = revwalk
341            .filter_map(Result::ok)
342            .skip(skip)
343            .take_while(|id| Some(*id) != base_id)
344            .collect::<Vec<_>>();
345        if let std::ops::Bound::Included(base_id) = base_bound {
346            result.push(*base_id);
347        }
348        Ok(result)
349    }
350
351    pub fn contains_commit(&self, haystack_id: git2::Oid, needle_id: git2::Oid) -> Result<bool> {
352        let needle_commit = self.repo.find_commit(needle_id)?;
353        let needle_ann_commit = self.repo.find_annotated_commit(needle_id)?;
354        let haystack_ann_commit = self.repo.find_annotated_commit(haystack_id)?;
355
356        let parent_ann_commit = if 0 < needle_commit.parent_count() {
357            let parent_commit = needle_commit.parent(0)?;
358            Some(self.repo.find_annotated_commit(parent_commit.id())?)
359        } else {
360            None
361        };
362
363        let mut rebase = self.repo.rebase(
364            Some(&needle_ann_commit),
365            parent_ann_commit.as_ref(),
366            Some(&haystack_ann_commit),
367            Some(git2::RebaseOptions::new().inmemory(true)),
368        )?;
369
370        if let Some(op) = rebase.next() {
371            op.inspect_err(|_e| {
372                let _ = rebase.abort();
373            })?;
374            let inmemory_index = rebase
375                .inmemory_index()
376                .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
377            if inmemory_index.has_conflicts() {
378                return Ok(false);
379            }
380
381            let sig = self
382                .repo
383                .signature()
384                .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
385            let result = rebase.commit(None, &sig, None).inspect_err(|_e| {
386                let _ = rebase.abort();
387            });
388            match result {
389                // Created commit, must be unique
390                Ok(_) => Ok(false),
391                Err(err) => {
392                    if err.class() == git2::ErrorClass::Rebase
393                        && err.code() == git2::ErrorCode::Applied
394                    {
395                        return Ok(true);
396                    }
397                    Err(err)
398                }
399            }
400        } else {
401            // No commit created, must exist somehow
402            rebase.finish(None)?;
403            Ok(true)
404        }
405    }
406
407    pub fn cherry_pick(&mut self, head_id: git2::Oid, cherry_id: git2::Oid) -> Result<git2::Oid> {
408        git2_ext::ops::cherry_pick(
409            &self.repo,
410            head_id,
411            cherry_id,
412            self.sign.as_ref().map(|s| s as &dyn git2_ext::ops::Sign),
413        )
414    }
415
416    pub fn reword(&mut self, head_oid: git2::Oid, msg: &str) -> Result<git2::Oid> {
417        git2_ext::ops::reword(
418            &self.repo,
419            head_oid,
420            msg,
421            self.sign.as_ref().map(|s| s as &dyn git2_ext::ops::Sign),
422        )
423    }
424
425    pub fn squash(&mut self, head_id: git2::Oid, into_id: git2::Oid) -> Result<git2::Oid> {
426        git2_ext::ops::squash(
427            &self.repo,
428            head_id,
429            into_id,
430            self.sign.as_ref().map(|s| s as &dyn git2_ext::ops::Sign),
431        )
432    }
433
434    pub fn stash_push(&mut self, message: Option<&str>) -> Result<git2::Oid> {
435        let signature = self.repo.signature()?;
436        self.repo.stash_save2(&signature, message, None)
437    }
438
439    pub fn stash_pop(&mut self, stash_id: git2::Oid) -> Result<()> {
440        let mut index = None;
441        self.repo.stash_foreach(|i, _, id| {
442            if *id == stash_id {
443                index = Some(i);
444                false
445            } else {
446                true
447            }
448        })?;
449        let index = index.ok_or_else(|| {
450            Error::new(
451                git2::ErrorCode::NotFound,
452                git2::ErrorClass::Reference,
453                "stash ID not found",
454            )
455        })?;
456        self.repo.stash_pop(index, None)
457    }
458
459    pub fn branch(&mut self, name: &str, id: git2::Oid) -> Result<()> {
460        let commit = self.repo.find_commit(id)?;
461        self.repo.branch(name, &commit, true)?;
462        Ok(())
463    }
464
465    pub fn delete_branch(&mut self, name: &str) -> Result<()> {
466        // HACK: We shouldn't limit ourselves to `Local`
467        let mut branch = self.repo.find_branch(name, git2::BranchType::Local)?;
468        branch.delete()
469    }
470
471    pub fn find_local_branch(&self, name: &str) -> Option<Branch> {
472        let branch = self.repo.find_branch(name, git2::BranchType::Local).ok()?;
473        self.load_local_branch(&branch, name).ok()
474    }
475
476    pub fn find_remote_branch(&self, remote: &str, name: &str) -> Option<Branch> {
477        let qualified = format!("{remote}/{name}");
478        let branch = self
479            .repo
480            .find_branch(&qualified, git2::BranchType::Remote)
481            .ok()?;
482        self.load_remote_branch(&branch, remote, name).ok()
483    }
484
485    pub fn local_branches(&self) -> impl Iterator<Item = Branch> + '_ {
486        log::trace!("Loading local branches");
487        self.repo
488            .branches(Some(git2::BranchType::Local))
489            .into_iter()
490            .flatten()
491            .filter_map(move |branch| {
492                let (branch, _) = branch.ok()?;
493                let name = if let Some(name) = branch.name().ok().flatten() {
494                    name
495                } else {
496                    log::debug!(
497                        "Ignoring non-UTF8 branch {:?}",
498                        branch.name_bytes().unwrap().as_bstr()
499                    );
500                    return None;
501                };
502                self.load_local_branch(&branch, name).ok()
503            })
504    }
505
506    pub fn remote_branches(&self) -> impl Iterator<Item = Branch> + '_ {
507        log::trace!("Loading remote branches");
508        self.repo
509            .branches(Some(git2::BranchType::Remote))
510            .into_iter()
511            .flatten()
512            .filter_map(move |branch| {
513                let (branch, _) = branch.ok()?;
514                let name = if let Some(name) = branch.name().ok().flatten() {
515                    name
516                } else {
517                    log::debug!(
518                        "Ignoring non-UTF8 branch {:?}",
519                        branch.name_bytes().unwrap().as_bstr()
520                    );
521                    return None;
522                };
523                let (remote, name) = name.split_once('/').unwrap();
524                self.load_remote_branch(&branch, remote, name).ok()
525            })
526    }
527
528    fn load_local_branch(&self, branch: &git2::Branch<'_>, name: &str) -> Result<Branch> {
529        let id = branch.get().target().unwrap();
530
531        Ok(Branch {
532            remote: None,
533            name: name.to_owned(),
534            id,
535        })
536    }
537
538    fn load_remote_branch(
539        &self,
540        branch: &git2::Branch<'_>,
541        remote: &str,
542        name: &str,
543    ) -> Result<Branch> {
544        let id = branch.get().target().unwrap();
545
546        Ok(Branch {
547            remote: Some(remote.to_owned()),
548            name: name.to_owned(),
549            id,
550        })
551    }
552
553    pub fn detach(&mut self) -> Result<()> {
554        let head_id = self
555            .repo
556            .head()
557            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"))
558            .resolve()
559            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"))
560            .target()
561            .unwrap();
562        self.repo.set_head_detached(head_id)?;
563        Ok(())
564    }
565
566    pub fn switch_branch(&mut self, name: &str) -> Result<()> {
567        let head_tree_id = self.head_commit().tree_id;
568        let branch = self.repo.find_branch(name, git2::BranchType::Local)?;
569        let target_id = branch.get().target().unwrap();
570        let target_tree_id = self
571            .find_commit(target_id)
572            .map(|c| c.tree_id)
573            .unwrap_or_else(git2::Oid::zero);
574
575        self.repo.set_head(branch.get().name().unwrap())?;
576
577        if head_tree_id != target_tree_id {
578            let mut builder = git2::build::CheckoutBuilder::new();
579            builder.force();
580            self.repo.checkout_head(Some(&mut builder))?;
581        }
582
583        Ok(())
584    }
585
586    pub fn switch_commit(&mut self, id: git2::Oid) -> Result<()> {
587        let head_tree_id = self.head_commit().tree_id;
588        let target_tree_id = self
589            .find_commit(id)
590            .map(|c| c.tree_id)
591            .unwrap_or_else(git2::Oid::zero);
592
593        self.repo.set_head_detached(id)?;
594
595        if head_tree_id != target_tree_id {
596            let mut builder = git2::build::CheckoutBuilder::new();
597            builder.force();
598            self.repo.checkout_head(Some(&mut builder))?;
599        }
600        Ok(())
601    }
602
603    fn intern_string(&self, data: &str) -> std::rc::Rc<str> {
604        let mut interned_strings = self.interned_strings.borrow_mut();
605        if let Some(interned) = interned_strings.get(data) {
606            std::rc::Rc::clone(interned)
607        } else {
608            let interned = std::rc::Rc::from(data);
609            interned_strings.insert(std::rc::Rc::clone(&interned));
610            interned
611        }
612    }
613}
614
615impl std::fmt::Debug for GitRepo {
616    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
617        f.debug_struct("GitRepo")
618            .field("repo", &self.repo.workdir())
619            .field("push_remote", &self.push_remote.as_deref())
620            .field("pull_remote", &self.pull_remote.as_deref())
621            .finish()
622    }
623}
624
625impl Repo for GitRepo {
626    fn path(&self) -> Option<&std::path::Path> {
627        Some(self.repo.path())
628    }
629    fn user(&self) -> Option<std::rc::Rc<str>> {
630        self.user()
631    }
632    fn push_remote(&self) -> &str {
633        self.push_remote()
634    }
635    fn pull_remote(&self) -> &str {
636        self.pull_remote()
637    }
638
639    fn is_dirty(&self) -> bool {
640        self.is_dirty()
641    }
642
643    fn merge_base(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid> {
644        self.merge_base(one, two)
645    }
646
647    fn find_commit(&self, id: git2::Oid) -> Option<std::rc::Rc<Commit>> {
648        self.find_commit(id)
649    }
650
651    fn head_commit(&self) -> std::rc::Rc<Commit> {
652        self.head_commit()
653    }
654
655    fn head_branch(&self) -> Option<Branch> {
656        self.head_branch()
657    }
658
659    fn resolve(&self, revspec: &str) -> Option<std::rc::Rc<Commit>> {
660        self.resolve(revspec)
661    }
662
663    fn parent_ids(&self, head_id: git2::Oid) -> Result<Vec<git2::Oid>> {
664        self.parent_ids(head_id)
665    }
666
667    fn commit_count(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize> {
668        self.commit_count(base_id, head_id)
669    }
670
671    fn commit_range(
672        &self,
673        base_bound: std::ops::Bound<&git2::Oid>,
674        head_bound: std::ops::Bound<&git2::Oid>,
675    ) -> Result<Vec<git2::Oid>> {
676        self.commit_range(base_bound, head_bound)
677    }
678
679    fn contains_commit(&self, haystack_id: git2::Oid, needle_id: git2::Oid) -> Result<bool> {
680        self.contains_commit(haystack_id, needle_id)
681    }
682
683    fn cherry_pick(&mut self, head_id: git2::Oid, cherry_id: git2::Oid) -> Result<git2::Oid> {
684        self.cherry_pick(head_id, cherry_id)
685    }
686
687    fn reword(&mut self, head_oid: git2::Oid, msg: &str) -> Result<git2::Oid> {
688        self.reword(head_oid, msg)
689    }
690
691    fn squash(&mut self, head_id: git2::Oid, into_id: git2::Oid) -> Result<git2::Oid> {
692        self.squash(head_id, into_id)
693    }
694
695    fn stash_push(&mut self, message: Option<&str>) -> Result<git2::Oid> {
696        self.stash_push(message)
697    }
698
699    fn stash_pop(&mut self, stash_id: git2::Oid) -> Result<()> {
700        self.stash_pop(stash_id)
701    }
702
703    fn branch(&mut self, name: &str, id: git2::Oid) -> Result<()> {
704        self.branch(name, id)
705    }
706
707    fn delete_branch(&mut self, name: &str) -> Result<()> {
708        self.delete_branch(name)
709    }
710
711    fn find_local_branch(&self, name: &str) -> Option<Branch> {
712        self.find_local_branch(name)
713    }
714
715    fn find_remote_branch(&self, remote: &str, name: &str) -> Option<Branch> {
716        self.find_remote_branch(remote, name)
717    }
718
719    fn local_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_> {
720        Box::new(self.local_branches())
721    }
722
723    fn remote_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_> {
724        Box::new(self.remote_branches())
725    }
726
727    fn detach(&mut self) -> Result<()> {
728        self.detach()
729    }
730
731    fn switch_branch(&mut self, name: &str) -> Result<()> {
732        self.switch_branch(name)
733    }
734
735    fn switch_commit(&mut self, id: git2::Oid) -> Result<()> {
736        self.switch_commit(id)
737    }
738}
739
/// Purely in-memory commit graph exposing the same operations as [`GitRepo`].
///
/// Each commit records at most one parent, so the history is linear (no merges).
#[derive(Debug)]
pub struct InMemoryRepo {
    /// Every known commit, keyed by id, paired with its optional parent id.
    commits: std::collections::HashMap<git2::Oid, (Option<git2::Oid>, std::rc::Rc<Commit>)>,
    /// Branches keyed by branch name.
    branches: std::collections::HashMap<String, Branch>,
    /// Current `HEAD` commit; `None` until a commit is pushed or selected.
    head_id: Option<git2::Oid>,

    /// Counter used by `gen_id` to mint unique fake commit ids.
    last_id: std::sync::atomic::AtomicUsize,
}
748
749impl InMemoryRepo {
750    pub fn new() -> Self {
751        Self {
752            commits: Default::default(),
753            branches: Default::default(),
754            head_id: Default::default(),
755            last_id: std::sync::atomic::AtomicUsize::new(1),
756        }
757    }
758
759    pub fn clear(&mut self) {
760        *self = InMemoryRepo::new();
761    }
762
763    pub fn gen_id(&mut self) -> git2::Oid {
764        let last_id = self
765            .last_id
766            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
767        let sha = format!("{last_id:040x}");
768        git2::Oid::from_str(&sha).unwrap()
769    }
770
771    pub fn push_commit(&mut self, parent_id: Option<git2::Oid>, commit: Commit) {
772        if let Some(parent_id) = parent_id {
773            assert!(self.commits.contains_key(&parent_id));
774        }
775        self.head_id = Some(commit.id);
776        self.commits
777            .insert(commit.id, (parent_id, std::rc::Rc::new(commit)));
778    }
779
780    pub fn head_id(&mut self) -> Option<git2::Oid> {
781        self.head_id
782    }
783
784    pub fn set_head(&mut self, head_id: git2::Oid) {
785        assert!(self.commits.contains_key(&head_id));
786        self.head_id = Some(head_id);
787    }
788
789    pub fn mark_branch(&mut self, branch: Branch) {
790        assert!(self.commits.contains_key(&branch.id));
791        self.branches.insert(branch.name.clone(), branch);
792    }
793
794    pub fn push_remote(&self) -> &str {
795        "origin"
796    }
797
798    pub fn pull_remote(&self) -> &str {
799        "origin"
800    }
801
802    fn user(&self) -> Option<std::rc::Rc<str>> {
803        None
804    }
805
806    pub fn is_dirty(&self) -> bool {
807        false
808    }
809
810    pub fn merge_base(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid> {
811        let one_ancestors: Vec<_> = self.commits_from(one).collect();
812        self.commits_from(two)
813            .filter(|two_ancestor| one_ancestors.contains(two_ancestor))
814            .map(|c| c.id)
815            .next()
816    }
817
818    pub fn find_commit(&self, id: git2::Oid) -> Option<std::rc::Rc<Commit>> {
819        self.commits.get(&id).map(|c| c.1.clone())
820    }
821
822    pub fn head_commit(&self) -> std::rc::Rc<Commit> {
823        self.commits.get(&self.head_id.unwrap()).cloned().unwrap().1
824    }
825
826    pub fn head_branch(&self) -> Option<Branch> {
827        self.branches
828            .values()
829            .find(|b| b.id == self.head_id.unwrap())
830            .cloned()
831    }
832
833    pub fn resolve(&self, revspec: &str) -> Option<std::rc::Rc<Commit>> {
834        let branch = self.branches.get(revspec)?;
835        self.find_commit(branch.id)
836    }
837
838    pub fn parent_ids(&self, head_id: git2::Oid) -> Result<Vec<git2::Oid>> {
839        let next = self
840            .commits
841            .get(&head_id)
842            .and_then(|(parent, _commit)| *parent);
843        Ok(next.into_iter().collect())
844    }
845
846    fn commits_from(&self, head_id: git2::Oid) -> impl Iterator<Item = std::rc::Rc<Commit>> + '_ {
847        let next = self.commits.get(&head_id).cloned();
848        CommitsFrom {
849            commits: &self.commits,
850            next,
851        }
852    }
853
854    pub fn commit_count(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize> {
855        let merge_base_id = self.merge_base(base_id, head_id)?;
856        let count = self
857            .commits_from(head_id)
858            .take_while(move |cur_id| cur_id.id != merge_base_id)
859            .count();
860        Some(count)
861    }
862
863    pub fn commit_range(
864        &self,
865        base_bound: std::ops::Bound<&git2::Oid>,
866        head_bound: std::ops::Bound<&git2::Oid>,
867    ) -> Result<Vec<git2::Oid>> {
868        let head_id = match head_bound {
869            std::ops::Bound::Included(head_id) | std::ops::Bound::Excluded(head_id) => *head_id,
870            std::ops::Bound::Unbounded => panic!("commit_range's HEAD cannot be unbounded"),
871        };
872        let skip = if matches!(head_bound, std::ops::Bound::Included(_)) {
873            0
874        } else {
875            1
876        };
877
878        let base_id = match base_bound {
879            std::ops::Bound::Included(base_id) | std::ops::Bound::Excluded(base_id) => {
880                debug_assert_eq!(self.merge_base(*base_id, head_id), Some(*base_id));
881                Some(*base_id)
882            }
883            std::ops::Bound::Unbounded => None,
884        };
885
886        let mut result = self
887            .commits_from(head_id)
888            .skip(skip)
889            .map(|commit| commit.id)
890            .take_while(|id| Some(*id) != base_id)
891            .collect::<Vec<_>>();
892        if let std::ops::Bound::Included(base_id) = base_bound {
893            result.push(*base_id);
894        }
895        Ok(result)
896    }
897
898    pub fn contains_commit(&self, haystack_id: git2::Oid, needle_id: git2::Oid) -> Result<bool> {
899        // Because we don't have the information for likeness matches, just checking for Oid
900        let mut next = Some(haystack_id);
901        while let Some(current) = next {
902            if current == needle_id {
903                return Ok(true);
904            }
905            next = self.commits.get(&current).and_then(|c| c.0);
906        }
907        Ok(false)
908    }
909
910    pub fn cherry_pick(&mut self, head_id: git2::Oid, cherry_id: git2::Oid) -> Result<git2::Oid> {
911        let cherry_commit = self.find_commit(cherry_id).ok_or_else(|| {
912            Error::new(
913                git2::ErrorCode::NotFound,
914                git2::ErrorClass::Reference,
915                format!("could not find commit {cherry_id:?}"),
916            )
917        })?;
918        let mut cherry_commit = Commit::clone(&cherry_commit);
919        let new_id = self.gen_id();
920        cherry_commit.id = new_id;
921        self.commits
922            .insert(new_id, (Some(head_id), std::rc::Rc::new(cherry_commit)));
923        Ok(new_id)
924    }
925
926    pub fn reword(&mut self, head_id: git2::Oid, msg: &str) -> Result<git2::Oid> {
927        let (head_parent, head_commit) = self.commits.get(&head_id).cloned().ok_or_else(|| {
928            git2::Error::new(
929                git2::ErrorCode::NotFound,
930                git2::ErrorClass::Reference,
931                format!("could not find commit {head_id:?}"),
932            )
933        })?;
934
935        let mut reworded_commit = Commit::clone(&head_commit);
936        let new_id = self.gen_id();
937        reworded_commit.id = new_id;
938        reworded_commit.summary = msg.into();
939        self.commits
940            .insert(new_id, (head_parent, std::rc::Rc::new(reworded_commit)));
941        Ok(new_id)
942    }
943
944    pub fn squash(&mut self, head_id: git2::Oid, into_id: git2::Oid) -> Result<git2::Oid> {
945        self.commits.get(&head_id).cloned().ok_or_else(|| {
946            Error::new(
947                git2::ErrorCode::NotFound,
948                git2::ErrorClass::Reference,
949                format!("could not find commit {head_id:?}"),
950            )
951        })?;
952        let (intos_parent, into_commit) = self.commits.get(&into_id).cloned().ok_or_else(|| {
953            Error::new(
954                git2::ErrorCode::NotFound,
955                git2::ErrorClass::Reference,
956                format!("could not find commit {into_id:?}"),
957            )
958        })?;
959
960        let mut squashed_commit = Commit::clone(&into_commit);
961        let new_id = self.gen_id();
962        squashed_commit.id = new_id;
963        self.commits
964            .insert(new_id, (intos_parent, std::rc::Rc::new(squashed_commit)));
965        Ok(new_id)
966    }
967
968    pub fn stash_push(&mut self, _message: Option<&str>) -> Result<git2::Oid> {
969        Err(Error::new(
970            git2::ErrorCode::NotFound,
971            git2::ErrorClass::Reference,
972            "stash is unsupported",
973        ))
974    }
975
976    pub fn stash_pop(&mut self, _stash_id: git2::Oid) -> Result<()> {
977        Err(Error::new(
978            git2::ErrorCode::NotFound,
979            git2::ErrorClass::Reference,
980            "stash is unsupported",
981        ))
982    }
983
984    pub fn branch(&mut self, name: &str, id: git2::Oid) -> Result<()> {
985        self.branches.insert(
986            name.to_owned(),
987            Branch {
988                remote: None,
989                name: name.to_owned(),
990                id,
991            },
992        );
993        Ok(())
994    }
995
996    pub fn delete_branch(&mut self, name: &str) -> Result<()> {
997        self.branches.remove(name).map(|_| ()).ok_or_else(|| {
998            Error::new(
999                git2::ErrorCode::NotFound,
1000                git2::ErrorClass::Reference,
1001                format!("could not remove branch {name:?}"),
1002            )
1003        })
1004    }
1005
1006    pub fn find_local_branch(&self, name: &str) -> Option<Branch> {
1007        self.branches.get(name).cloned()
1008    }
1009
1010    pub fn find_remote_branch(&self, _remote: &str, _name: &str) -> Option<Branch> {
1011        None
1012    }
1013
1014    pub fn local_branches(&self) -> impl Iterator<Item = Branch> + '_ {
1015        self.branches.values().cloned()
1016    }
1017
1018    pub fn remote_branches(&self) -> impl Iterator<Item = Branch> + '_ {
1019        None.into_iter()
1020    }
1021
1022    pub fn detach(&mut self) -> Result<()> {
1023        Ok(())
1024    }
1025
1026    pub fn switch_branch(&mut self, name: &str) -> Result<()> {
1027        let branch = self.find_local_branch(name).ok_or_else(|| {
1028            Error::new(
1029                git2::ErrorCode::NotFound,
1030                git2::ErrorClass::Reference,
1031                format!("could not find branch {name:?}"),
1032            )
1033        })?;
1034        self.head_id = Some(branch.id);
1035        Ok(())
1036    }
1037
1038    pub fn switch_commit(&mut self, id: git2::Oid) -> Result<()> {
1039        self.head_id = Some(id);
1040        Ok(())
1041    }
1042}
1043
1044impl Default for InMemoryRepo {
1045    fn default() -> Self {
1046        Self::new()
1047    }
1048}
1049
1050struct CommitsFrom<'c> {
1051    commits: &'c std::collections::HashMap<git2::Oid, (Option<git2::Oid>, std::rc::Rc<Commit>)>,
1052    next: Option<(Option<git2::Oid>, std::rc::Rc<Commit>)>,
1053}
1054
1055impl Iterator for CommitsFrom<'_> {
1056    type Item = std::rc::Rc<Commit>;
1057
1058    fn next(&mut self) -> Option<Self::Item> {
1059        let mut current = None;
1060        std::mem::swap(&mut current, &mut self.next);
1061        let current = current?;
1062        if let Some(parent_id) = current.0 {
1063            self.next = self.commits.get(&parent_id).cloned();
1064        }
1065        Some(current.1)
1066    }
1067}
1068
1069impl Repo for InMemoryRepo {
1070    fn path(&self) -> Option<&std::path::Path> {
1071        None
1072    }
1073    fn user(&self) -> Option<std::rc::Rc<str>> {
1074        self.user()
1075    }
1076    fn push_remote(&self) -> &str {
1077        self.push_remote()
1078    }
1079    fn pull_remote(&self) -> &str {
1080        self.pull_remote()
1081    }
1082
1083    fn is_dirty(&self) -> bool {
1084        self.is_dirty()
1085    }
1086
1087    fn merge_base(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid> {
1088        self.merge_base(one, two)
1089    }
1090
1091    fn find_commit(&self, id: git2::Oid) -> Option<std::rc::Rc<Commit>> {
1092        self.find_commit(id)
1093    }
1094
1095    fn head_commit(&self) -> std::rc::Rc<Commit> {
1096        self.head_commit()
1097    }
1098
1099    fn resolve(&self, revspec: &str) -> Option<std::rc::Rc<Commit>> {
1100        self.resolve(revspec)
1101    }
1102
1103    fn parent_ids(&self, head_id: git2::Oid) -> Result<Vec<git2::Oid>> {
1104        self.parent_ids(head_id)
1105    }
1106
1107    fn commit_count(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize> {
1108        self.commit_count(base_id, head_id)
1109    }
1110
1111    fn commit_range(
1112        &self,
1113        base_bound: std::ops::Bound<&git2::Oid>,
1114        head_bound: std::ops::Bound<&git2::Oid>,
1115    ) -> Result<Vec<git2::Oid>> {
1116        self.commit_range(base_bound, head_bound)
1117    }
1118
1119    fn contains_commit(&self, haystack_id: git2::Oid, needle_id: git2::Oid) -> Result<bool> {
1120        self.contains_commit(haystack_id, needle_id)
1121    }
1122
1123    fn cherry_pick(&mut self, head_id: git2::Oid, cherry_id: git2::Oid) -> Result<git2::Oid> {
1124        self.cherry_pick(head_id, cherry_id)
1125    }
1126
1127    fn reword(&mut self, head_oid: git2::Oid, msg: &str) -> Result<git2::Oid> {
1128        self.reword(head_oid, msg)
1129    }
1130
1131    fn squash(&mut self, head_id: git2::Oid, into_id: git2::Oid) -> Result<git2::Oid> {
1132        self.squash(head_id, into_id)
1133    }
1134
1135    fn head_branch(&self) -> Option<Branch> {
1136        self.head_branch()
1137    }
1138
1139    fn stash_push(&mut self, message: Option<&str>) -> Result<git2::Oid> {
1140        self.stash_push(message)
1141    }
1142
1143    fn stash_pop(&mut self, stash_id: git2::Oid) -> Result<()> {
1144        self.stash_pop(stash_id)
1145    }
1146
1147    fn branch(&mut self, name: &str, id: git2::Oid) -> Result<()> {
1148        self.branch(name, id)
1149    }
1150
1151    fn delete_branch(&mut self, name: &str) -> Result<()> {
1152        self.delete_branch(name)
1153    }
1154
1155    fn find_local_branch(&self, name: &str) -> Option<Branch> {
1156        self.find_local_branch(name)
1157    }
1158
1159    fn find_remote_branch(&self, remote: &str, name: &str) -> Option<Branch> {
1160        self.find_remote_branch(remote, name)
1161    }
1162
1163    fn local_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_> {
1164        Box::new(self.local_branches())
1165    }
1166
1167    fn remote_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_> {
1168        Box::new(self.remote_branches())
1169    }
1170
1171    fn detach(&mut self) -> Result<()> {
1172        self.detach()
1173    }
1174
1175    fn switch_branch(&mut self, name: &str) -> Result<()> {
1176        self.switch_branch(name)
1177    }
1178
1179    fn switch_commit(&mut self, id: git2::Oid) -> Result<()> {
1180        self.switch_commit(id)
1181    }
1182}
1183
1184pub fn stash_push(repo: &mut dyn Repo, context: &str) -> Option<git2::Oid> {
1185    let branch = repo.head_branch();
1186    if !repo.is_dirty() {
1187        log::debug!("Nothing to stash");
1188        return None;
1189    }
1190
1191    let stash_msg = format!(
1192        "WIP on {} ({})",
1193        branch.as_ref().map(|b| b.name.as_str()).unwrap_or("HEAD"),
1194        context
1195    );
1196    match repo.stash_push(Some(&stash_msg)) {
1197        Ok(stash_id) => {
1198            log::info!(
1199                "Saved working directory and index state {}: {}",
1200                stash_msg,
1201                stash_id
1202            );
1203            Some(stash_id)
1204        }
1205        Err(err) => {
1206            log::debug!("Failed to stash: {}", err);
1207            None
1208        }
1209    }
1210}
1211
1212pub fn stash_pop(repo: &mut dyn Repo, stash_id: Option<git2::Oid>) {
1213    if let Some(stash_id) = stash_id {
1214        match repo.stash_pop(stash_id) {
1215            Ok(()) => {
1216                log::info!("Dropped refs/stash {}", stash_id);
1217            }
1218            Err(err) => {
1219                log::error!("Failed to pop {} from stash: {}", stash_id, err);
1220            }
1221        }
1222    }
1223}
1224
1225pub fn commit_range(
1226    repo: &dyn Repo,
1227    head_to_base: impl std::ops::RangeBounds<git2::Oid>,
1228) -> Result<Vec<git2::Oid>> {
1229    let head_bound = head_to_base.start_bound();
1230    let base_bound = head_to_base.end_bound();
1231    repo.commit_range(base_bound, head_bound)
1232}