git_stack/legacy/git/
repo.rs

1use bstr::ByteSlice;
2use itertools::Itertools;
3
4pub trait Repo {
5    fn path(&self) -> Option<&std::path::Path>;
6    fn user(&self) -> Option<std::rc::Rc<str>>;
7
8    fn is_dirty(&self) -> bool;
9    fn merge_base(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid>;
10
11    fn find_commit(&self, id: git2::Oid) -> Option<std::rc::Rc<Commit>>;
12    fn head_commit(&self) -> std::rc::Rc<Commit>;
13    fn head_branch(&self) -> Option<Branch>;
14    fn resolve(&self, revspec: &str) -> Option<std::rc::Rc<Commit>>;
15    fn parent_ids(&self, head_id: git2::Oid) -> Result<Vec<git2::Oid>, git2::Error>;
16    fn commit_count(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize>;
17    fn commit_range(
18        &self,
19        base_bound: std::ops::Bound<&git2::Oid>,
20        head_bound: std::ops::Bound<&git2::Oid>,
21    ) -> Result<Vec<git2::Oid>, git2::Error>;
22    fn contains_commit(
23        &self,
24        haystack_id: git2::Oid,
25        needle_id: git2::Oid,
26    ) -> Result<bool, git2::Error>;
27    fn cherry_pick(
28        &mut self,
29        head_id: git2::Oid,
30        cherry_id: git2::Oid,
31    ) -> Result<git2::Oid, git2::Error>;
32    fn squash(&mut self, head_id: git2::Oid, into_id: git2::Oid) -> Result<git2::Oid, git2::Error>;
33
34    fn stash_push(&mut self, message: Option<&str>) -> Result<git2::Oid, git2::Error>;
35    fn stash_pop(&mut self, stash_id: git2::Oid) -> Result<(), git2::Error>;
36
37    fn branch(&mut self, name: &str, id: git2::Oid) -> Result<(), git2::Error>;
38    fn delete_branch(&mut self, name: &str) -> Result<(), git2::Error>;
39    fn find_local_branch(&self, name: &str) -> Option<Branch>;
40    fn find_remote_branch(&self, remote: &str, name: &str) -> Option<Branch>;
41    fn local_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_>;
42    fn remote_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_>;
43    fn detach(&mut self) -> Result<(), git2::Error>;
44    fn switch(&mut self, name: &str) -> Result<(), git2::Error>;
45}
46
47#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
48pub struct Branch {
49    pub remote: Option<String>,
50    pub name: String,
51    pub id: git2::Oid,
52    pub push_id: Option<git2::Oid>,
53    pub pull_id: Option<git2::Oid>,
54}
55
56impl Branch {
57    pub fn local_name(&self) -> Option<&str> {
58        self.remote.is_none().then_some(self.name.as_str())
59    }
60}
61
62impl std::fmt::Display for Branch {
63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64        if let Some(remote) = self.remote.as_deref() {
65            write!(f, "{}/{}", remote, self.name.as_str())
66        } else {
67            write!(f, "{}", self.name.as_str())
68        }
69    }
70}
71
72#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
73pub struct Commit {
74    pub id: git2::Oid,
75    pub tree_id: git2::Oid,
76    pub summary: bstr::BString,
77    pub time: std::time::SystemTime,
78    pub author: Option<std::rc::Rc<str>>,
79    pub committer: Option<std::rc::Rc<str>>,
80}
81
82impl Commit {
83    pub fn fixup_summary(&self) -> Option<&bstr::BStr> {
84        self.summary
85            .strip_prefix(b"fixup! ")
86            .map(ByteSlice::as_bstr)
87    }
88
89    pub fn wip_summary(&self) -> Option<&bstr::BStr> {
90        // Gitlab MRs only: b"[Draft]", b"(Draft)",
91        static WIP_PREFIXES: &[&[u8]] = &[
92            b"WIP:", b"draft:", b"Draft:", // Gitlab commits
93            b"wip ", b"WIP ", // Less formal
94        ];
95
96        if self.summary == b"WIP".as_bstr() || self.summary == b"wip".as_bstr() {
97            // Very informal
98            Some(b"".as_bstr())
99        } else {
100            WIP_PREFIXES.iter().find_map(|prefix| {
101                self.summary
102                    .strip_prefix(*prefix)
103                    .map(ByteSlice::trim)
104                    .map(ByteSlice::as_bstr)
105            })
106        }
107    }
108
109    pub fn revert_summary(&self) -> Option<&bstr::BStr> {
110        self.summary
111            .strip_prefix(b"Revert ")
112            .and_then(|s| s.strip_suffix(b"\""))
113            .map(ByteSlice::as_bstr)
114    }
115}
116
117pub struct GitRepo {
118    repo: git2::Repository,
119    sign: Option<git2_ext::ops::UserSign>,
120    push_remote: Option<String>,
121    pull_remote: Option<String>,
122    commits: std::cell::RefCell<std::collections::HashMap<git2::Oid, std::rc::Rc<Commit>>>,
123    interned_strings: std::cell::RefCell<std::collections::HashSet<std::rc::Rc<str>>>,
124    bases: std::cell::RefCell<std::collections::HashMap<(git2::Oid, git2::Oid), Option<git2::Oid>>>,
125    counts: std::cell::RefCell<std::collections::HashMap<(git2::Oid, git2::Oid), Option<usize>>>,
126}
127
128impl GitRepo {
129    pub fn new(repo: git2::Repository) -> Self {
130        Self {
131            repo,
132            sign: None,
133            push_remote: None,
134            pull_remote: None,
135            commits: Default::default(),
136            interned_strings: Default::default(),
137            bases: Default::default(),
138            counts: Default::default(),
139        }
140    }
141
142    pub fn set_sign(&mut self, yes: bool) -> Result<(), git2::Error> {
143        if yes {
144            let config = self.repo.config()?;
145            let sign = git2_ext::ops::UserSign::from_config(&self.repo, &config)?;
146            self.sign = Some(sign);
147        } else {
148            self.sign = None;
149        }
150        Ok(())
151    }
152
153    pub fn set_push_remote(&mut self, remote: &str) {
154        self.push_remote = Some(remote.to_owned());
155    }
156
157    pub fn set_pull_remote(&mut self, remote: &str) {
158        self.pull_remote = Some(remote.to_owned());
159    }
160
161    pub fn push_remote(&self) -> &str {
162        self.push_remote.as_deref().unwrap_or("origin")
163    }
164
165    pub fn pull_remote(&self) -> &str {
166        self.pull_remote.as_deref().unwrap_or("origin")
167    }
168
169    pub fn raw(&self) -> &git2::Repository {
170        &self.repo
171    }
172
173    pub fn user(&self) -> Option<std::rc::Rc<str>> {
174        self.repo
175            .signature()
176            .ok()
177            .and_then(|s| s.name().map(|n| self.intern_string(n)))
178    }
179
180    pub fn is_dirty(&self) -> bool {
181        if self.repo.state() != git2::RepositoryState::Clean {
182            log::trace!("Repository status is unclean: {:?}", self.repo.state());
183            return true;
184        }
185
186        let status = self
187            .repo
188            .statuses(Some(git2::StatusOptions::new().include_ignored(false)))
189            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
190        if status.is_empty() {
191            false
192        } else {
193            log::trace!(
194                "Repository is dirty: {}",
195                status
196                    .iter()
197                    .filter_map(|s| s.path().map(|s| s.to_owned()))
198                    .join(", ")
199            );
200            true
201        }
202    }
203
204    pub fn merge_base(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid> {
205        if one == two {
206            return Some(one);
207        }
208
209        let (smaller, larger) = if one < two { (one, two) } else { (two, one) };
210        *self
211            .bases
212            .borrow_mut()
213            .entry((smaller, larger))
214            .or_insert_with(|| self.merge_base_raw(smaller, larger))
215    }
216
217    fn merge_base_raw(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid> {
218        self.repo.merge_base(one, two).ok()
219    }
220
221    pub fn find_commit(&self, id: git2::Oid) -> Option<std::rc::Rc<Commit>> {
222        let mut commits = self.commits.borrow_mut();
223        if let Some(commit) = commits.get(&id) {
224            Some(std::rc::Rc::clone(commit))
225        } else {
226            let commit = self.repo.find_commit(id).ok()?;
227            let summary: bstr::BString = commit.summary_bytes().unwrap().into();
228            let time = std::time::SystemTime::UNIX_EPOCH
229                + std::time::Duration::from_secs(commit.time().seconds().max(0) as u64);
230
231            let author = commit.author().name().map(|n| self.intern_string(n));
232            let committer = commit.author().name().map(|n| self.intern_string(n));
233            let commit = std::rc::Rc::new(Commit {
234                id: commit.id(),
235                tree_id: commit.tree_id(),
236                summary,
237                time,
238                author,
239                committer,
240            });
241            commits.insert(id, std::rc::Rc::clone(&commit));
242            Some(commit)
243        }
244    }
245
246    pub fn head_commit(&self) -> std::rc::Rc<Commit> {
247        let head_id = self
248            .repo
249            .head()
250            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"))
251            .resolve()
252            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"))
253            .target()
254            .unwrap();
255        self.find_commit(head_id).unwrap()
256    }
257
258    pub fn head_branch(&self) -> Option<Branch> {
259        let resolved = self
260            .repo
261            .head()
262            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"))
263            .resolve()
264            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
265        let name = resolved.shorthand()?;
266        let id = resolved.target()?;
267
268        let push_id = self
269            .repo
270            .find_branch(
271                &format!("{}/{}", self.push_remote(), name),
272                git2::BranchType::Remote,
273            )
274            .ok()
275            .and_then(|b| b.get().target());
276        let pull_id = self
277            .repo
278            .find_branch(
279                &format!("{}/{}", self.pull_remote(), name),
280                git2::BranchType::Remote,
281            )
282            .ok()
283            .and_then(|b| b.get().target());
284
285        Some(Branch {
286            remote: None,
287            name: name.to_owned(),
288            id,
289            push_id,
290            pull_id,
291        })
292    }
293
294    pub fn resolve(&self, revspec: &str) -> Option<std::rc::Rc<Commit>> {
295        let id = self.repo.revparse_single(revspec).ok()?.id();
296        self.find_commit(id)
297    }
298
299    pub fn parent_ids(&self, head_id: git2::Oid) -> Result<Vec<git2::Oid>, git2::Error> {
300        let commit = self.repo.find_commit(head_id)?;
301        Ok(commit.parent_ids().collect())
302    }
303
304    pub fn commit_count(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize> {
305        if base_id == head_id {
306            return Some(0);
307        }
308
309        *self
310            .counts
311            .borrow_mut()
312            .entry((base_id, head_id))
313            .or_insert_with(|| self.commit_count_raw(base_id, head_id))
314    }
315
316    fn commit_count_raw(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize> {
317        let merge_base_id = self.merge_base(base_id, head_id)?;
318        if merge_base_id != base_id {
319            return None;
320        }
321        let mut revwalk = self
322            .repo
323            .revwalk()
324            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
325        revwalk
326            .push(head_id)
327            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
328        revwalk
329            .hide(base_id)
330            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
331        Some(revwalk.count())
332    }
333
334    pub fn commit_range(
335        &self,
336        base_bound: std::ops::Bound<&git2::Oid>,
337        head_bound: std::ops::Bound<&git2::Oid>,
338    ) -> Result<Vec<git2::Oid>, git2::Error> {
339        let head_id = match head_bound {
340            std::ops::Bound::Included(head_id) | std::ops::Bound::Excluded(head_id) => *head_id,
341            std::ops::Bound::Unbounded => panic!("commit_range's HEAD cannot be unbounded"),
342        };
343        let skip = if matches!(head_bound, std::ops::Bound::Included(_)) {
344            0
345        } else {
346            1
347        };
348
349        let base_id = match base_bound {
350            std::ops::Bound::Included(base_id) | std::ops::Bound::Excluded(base_id) => {
351                debug_assert_eq!(self.merge_base(*base_id, head_id), Some(*base_id));
352                Some(*base_id)
353            }
354            std::ops::Bound::Unbounded => None,
355        };
356
357        let mut revwalk = self.repo.revwalk()?;
358        revwalk.push(head_id)?;
359        if let Some(base_id) = base_id {
360            revwalk.hide(base_id)?;
361        }
362        revwalk.set_sorting(git2::Sort::TOPOLOGICAL)?;
363        let mut result = revwalk
364            .filter_map(Result::ok)
365            .skip(skip)
366            .take_while(|id| Some(*id) != base_id)
367            .collect::<Vec<_>>();
368        if let std::ops::Bound::Included(base_id) = base_bound {
369            result.push(*base_id);
370        }
371        Ok(result)
372    }
373
374    pub fn contains_commit(
375        &self,
376        haystack_id: git2::Oid,
377        needle_id: git2::Oid,
378    ) -> Result<bool, git2::Error> {
379        let needle_commit = self.repo.find_commit(needle_id)?;
380        let needle_ann_commit = self.repo.find_annotated_commit(needle_id)?;
381        let haystack_ann_commit = self.repo.find_annotated_commit(haystack_id)?;
382
383        let parent_ann_commit = if 0 < needle_commit.parent_count() {
384            let parent_commit = needle_commit.parent(0)?;
385            Some(self.repo.find_annotated_commit(parent_commit.id())?)
386        } else {
387            None
388        };
389
390        let mut rebase = self.repo.rebase(
391            Some(&needle_ann_commit),
392            parent_ann_commit.as_ref(),
393            Some(&haystack_ann_commit),
394            Some(git2::RebaseOptions::new().inmemory(true)),
395        )?;
396
397        if let Some(op) = rebase.next() {
398            op.inspect_err(|_e| {
399                let _ = rebase.abort();
400            })?;
401            let inmemory_index = rebase
402                .inmemory_index()
403                .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
404            if inmemory_index.has_conflicts() {
405                return Ok(false);
406            }
407
408            let sig = self
409                .repo
410                .signature()
411                .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
412            let result = rebase.commit(None, &sig, None).inspect_err(|_e| {
413                let _ = rebase.abort();
414            });
415            match result {
416                // Created commit, must be unique
417                Ok(_) => Ok(false),
418                Err(err) => {
419                    if err.class() == git2::ErrorClass::Rebase
420                        && err.code() == git2::ErrorCode::Applied
421                    {
422                        return Ok(true);
423                    }
424                    Err(err)
425                }
426            }
427        } else {
428            // No commit created, must exist somehow
429            rebase.finish(None)?;
430            Ok(true)
431        }
432    }
433
434    fn cherry_pick(
435        &mut self,
436        head_id: git2::Oid,
437        cherry_id: git2::Oid,
438    ) -> Result<git2::Oid, git2::Error> {
439        git2_ext::ops::cherry_pick(
440            &self.repo,
441            head_id,
442            cherry_id,
443            self.sign.as_ref().map(|s| s as &dyn git2_ext::ops::Sign),
444        )
445    }
446
447    pub fn squash(
448        &mut self,
449        head_id: git2::Oid,
450        into_id: git2::Oid,
451    ) -> Result<git2::Oid, git2::Error> {
452        git2_ext::ops::squash(
453            &self.repo,
454            head_id,
455            into_id,
456            self.sign.as_ref().map(|s| s as &dyn git2_ext::ops::Sign),
457        )
458    }
459
460    pub fn stash_push(&mut self, message: Option<&str>) -> Result<git2::Oid, git2::Error> {
461        let signature = self.repo.signature()?;
462        self.repo.stash_save2(&signature, message, None)
463    }
464
465    pub fn stash_pop(&mut self, stash_id: git2::Oid) -> Result<(), git2::Error> {
466        let mut index = None;
467        self.repo.stash_foreach(|i, _, id| {
468            if *id == stash_id {
469                index = Some(i);
470                false
471            } else {
472                true
473            }
474        })?;
475        let index = index.ok_or_else(|| {
476            git2::Error::new(
477                git2::ErrorCode::NotFound,
478                git2::ErrorClass::Reference,
479                "stash ID not found",
480            )
481        })?;
482        self.repo.stash_pop(index, None)
483    }
484
485    pub fn branch(&mut self, name: &str, id: git2::Oid) -> Result<(), git2::Error> {
486        let commit = self.repo.find_commit(id)?;
487        self.repo.branch(name, &commit, true)?;
488        Ok(())
489    }
490
491    pub fn delete_branch(&mut self, name: &str) -> Result<(), git2::Error> {
492        // HACK: We shouldn't limit ourselves to `Local`
493        let mut branch = self.repo.find_branch(name, git2::BranchType::Local)?;
494        branch.delete()
495    }
496
497    pub fn find_local_branch(&self, name: &str) -> Option<Branch> {
498        let branch = self.repo.find_branch(name, git2::BranchType::Local).ok()?;
499        self.load_local_branch(&branch, name).ok()
500    }
501
502    pub fn find_remote_branch(&self, remote: &str, name: &str) -> Option<Branch> {
503        let qualified = format!("{remote}/{name}");
504        let branch = self
505            .repo
506            .find_branch(&qualified, git2::BranchType::Remote)
507            .ok()?;
508        self.load_remote_branch(&branch, remote, name).ok()
509    }
510
511    pub fn local_branches(&self) -> impl Iterator<Item = Branch> + '_ {
512        log::trace!("Loading local branches");
513        self.repo
514            .branches(Some(git2::BranchType::Local))
515            .into_iter()
516            .flatten()
517            .filter_map(move |branch| {
518                let (branch, _) = branch.ok()?;
519                let name = if let Some(name) = branch.name().ok().flatten() {
520                    name
521                } else {
522                    log::debug!(
523                        "Ignoring non-UTF8 branch {:?}",
524                        branch.name_bytes().unwrap().as_bstr()
525                    );
526                    return None;
527                };
528                self.load_local_branch(&branch, name).ok()
529            })
530    }
531
532    pub fn remote_branches(&self) -> impl Iterator<Item = Branch> + '_ {
533        log::trace!("Loading remote branches");
534        self.repo
535            .branches(Some(git2::BranchType::Remote))
536            .into_iter()
537            .flatten()
538            .filter_map(move |branch| {
539                let (branch, _) = branch.ok()?;
540                let name = if let Some(name) = branch.name().ok().flatten() {
541                    name
542                } else {
543                    log::debug!(
544                        "Ignoring non-UTF8 branch {:?}",
545                        branch.name_bytes().unwrap().as_bstr()
546                    );
547                    return None;
548                };
549                let (remote, name) = name.split_once('/').unwrap();
550                self.load_remote_branch(&branch, remote, name).ok()
551            })
552    }
553
554    fn load_local_branch(
555        &self,
556        branch: &git2::Branch<'_>,
557        name: &str,
558    ) -> Result<Branch, git2::Error> {
559        let id = branch.get().target().unwrap();
560
561        let push_id = self
562            .repo
563            .find_branch(
564                &format!("{}/{}", self.push_remote(), name),
565                git2::BranchType::Remote,
566            )
567            .ok()
568            .and_then(|b| b.get().target());
569        let pull_id = self
570            .repo
571            .find_branch(
572                &format!("{}/{}", self.pull_remote(), name),
573                git2::BranchType::Remote,
574            )
575            .ok()
576            .and_then(|b| b.get().target());
577
578        Ok(Branch {
579            remote: None,
580            name: name.to_owned(),
581            id,
582            push_id,
583            pull_id,
584        })
585    }
586
587    fn load_remote_branch(
588        &self,
589        branch: &git2::Branch<'_>,
590        remote: &str,
591        name: &str,
592    ) -> Result<Branch, git2::Error> {
593        let id = branch.get().target().unwrap();
594
595        let push_id = (remote == self.push_remote()).then_some(id);
596        let pull_id = (remote == self.pull_remote()).then_some(id);
597
598        Ok(Branch {
599            remote: Some(remote.to_owned()),
600            name: name.to_owned(),
601            id,
602            push_id,
603            pull_id,
604        })
605    }
606
607    pub fn detach(&mut self) -> Result<(), git2::Error> {
608        let head_id = self
609            .repo
610            .head()
611            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"))
612            .resolve()
613            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"))
614            .target()
615            .unwrap();
616        self.repo.set_head_detached(head_id)?;
617        Ok(())
618    }
619
620    pub fn switch(&mut self, name: &str) -> Result<(), git2::Error> {
621        // HACK: We shouldn't limit ourselves to `Local`
622        let branch = self.repo.find_branch(name, git2::BranchType::Local)?;
623        self.repo.set_head(branch.get().name().unwrap())?;
624        let mut builder = git2::build::CheckoutBuilder::new();
625        builder.force();
626        self.repo.checkout_head(Some(&mut builder))?;
627        Ok(())
628    }
629
630    fn intern_string(&self, data: &str) -> std::rc::Rc<str> {
631        let mut interned_strings = self.interned_strings.borrow_mut();
632        if let Some(interned) = interned_strings.get(data) {
633            std::rc::Rc::clone(interned)
634        } else {
635            let interned = std::rc::Rc::from(data);
636            interned_strings.insert(std::rc::Rc::clone(&interned));
637            interned
638        }
639    }
640}
641
642impl std::fmt::Debug for GitRepo {
643    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
644        f.debug_struct("GitRepo")
645            .field("repo", &self.repo.workdir())
646            .field("push_remote", &self.push_remote.as_deref())
647            .field("pull_remote", &self.pull_remote.as_deref())
648            .finish()
649    }
650}
651
652impl Repo for GitRepo {
653    fn path(&self) -> Option<&std::path::Path> {
654        Some(self.repo.path())
655    }
656    fn user(&self) -> Option<std::rc::Rc<str>> {
657        self.user()
658    }
659
660    fn is_dirty(&self) -> bool {
661        self.is_dirty()
662    }
663
664    fn merge_base(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid> {
665        self.merge_base(one, two)
666    }
667
668    fn find_commit(&self, id: git2::Oid) -> Option<std::rc::Rc<Commit>> {
669        self.find_commit(id)
670    }
671
672    fn head_commit(&self) -> std::rc::Rc<Commit> {
673        self.head_commit()
674    }
675
676    fn head_branch(&self) -> Option<Branch> {
677        self.head_branch()
678    }
679
680    fn resolve(&self, revspec: &str) -> Option<std::rc::Rc<Commit>> {
681        self.resolve(revspec)
682    }
683
684    fn parent_ids(&self, head_id: git2::Oid) -> Result<Vec<git2::Oid>, git2::Error> {
685        self.parent_ids(head_id)
686    }
687
688    fn commit_count(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize> {
689        self.commit_count(base_id, head_id)
690    }
691
692    fn commit_range(
693        &self,
694        base_bound: std::ops::Bound<&git2::Oid>,
695        head_bound: std::ops::Bound<&git2::Oid>,
696    ) -> Result<Vec<git2::Oid>, git2::Error> {
697        self.commit_range(base_bound, head_bound)
698    }
699
700    fn contains_commit(
701        &self,
702        haystack_id: git2::Oid,
703        needle_id: git2::Oid,
704    ) -> Result<bool, git2::Error> {
705        self.contains_commit(haystack_id, needle_id)
706    }
707
708    fn cherry_pick(
709        &mut self,
710        head_id: git2::Oid,
711        cherry_id: git2::Oid,
712    ) -> Result<git2::Oid, git2::Error> {
713        self.cherry_pick(head_id, cherry_id)
714    }
715
716    fn squash(&mut self, head_id: git2::Oid, into_id: git2::Oid) -> Result<git2::Oid, git2::Error> {
717        self.squash(head_id, into_id)
718    }
719
720    fn stash_push(&mut self, message: Option<&str>) -> Result<git2::Oid, git2::Error> {
721        self.stash_push(message)
722    }
723
724    fn stash_pop(&mut self, stash_id: git2::Oid) -> Result<(), git2::Error> {
725        self.stash_pop(stash_id)
726    }
727
728    fn branch(&mut self, name: &str, id: git2::Oid) -> Result<(), git2::Error> {
729        self.branch(name, id)
730    }
731
732    fn delete_branch(&mut self, name: &str) -> Result<(), git2::Error> {
733        self.delete_branch(name)
734    }
735
736    fn find_local_branch(&self, name: &str) -> Option<Branch> {
737        self.find_local_branch(name)
738    }
739
740    fn find_remote_branch(&self, remote: &str, name: &str) -> Option<Branch> {
741        self.find_remote_branch(remote, name)
742    }
743
744    fn local_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_> {
745        Box::new(self.local_branches())
746    }
747
748    fn remote_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_> {
749        Box::new(self.remote_branches())
750    }
751
752    fn detach(&mut self) -> Result<(), git2::Error> {
753        self.detach()
754    }
755
756    fn switch(&mut self, name: &str) -> Result<(), git2::Error> {
757        self.switch(name)
758    }
759}
760
761#[derive(Debug)]
762pub struct InMemoryRepo {
763    commits: std::collections::HashMap<git2::Oid, (Option<git2::Oid>, std::rc::Rc<Commit>)>,
764    branches: std::collections::HashMap<String, Branch>,
765    head_id: Option<git2::Oid>,
766
767    last_id: std::sync::atomic::AtomicUsize,
768}
769
770impl InMemoryRepo {
771    pub fn new() -> Self {
772        Self {
773            commits: Default::default(),
774            branches: Default::default(),
775            head_id: Default::default(),
776            last_id: std::sync::atomic::AtomicUsize::new(1),
777        }
778    }
779
780    pub fn clear(&mut self) {
781        *self = InMemoryRepo::new();
782    }
783
784    pub fn gen_id(&mut self) -> git2::Oid {
785        let last_id = self
786            .last_id
787            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
788        let sha = format!("{last_id:040x}");
789        git2::Oid::from_str(&sha).unwrap()
790    }
791
792    pub fn push_commit(&mut self, parent_id: Option<git2::Oid>, commit: Commit) {
793        if let Some(parent_id) = parent_id {
794            assert!(self.commits.contains_key(&parent_id));
795        }
796        self.head_id = Some(commit.id);
797        self.commits
798            .insert(commit.id, (parent_id, std::rc::Rc::new(commit)));
799    }
800
801    pub fn head_id(&mut self) -> Option<git2::Oid> {
802        self.head_id
803    }
804
805    pub fn set_head(&mut self, head_id: git2::Oid) {
806        assert!(self.commits.contains_key(&head_id));
807        self.head_id = Some(head_id);
808    }
809
810    pub fn mark_branch(&mut self, branch: Branch) {
811        assert!(self.commits.contains_key(&branch.id));
812        self.branches.insert(branch.name.clone(), branch);
813    }
814
815    fn user(&self) -> Option<std::rc::Rc<str>> {
816        None
817    }
818
819    pub fn is_dirty(&self) -> bool {
820        false
821    }
822
823    pub fn merge_base(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid> {
824        let one_ancestors: Vec<_> = self.commits_from(one).collect();
825        self.commits_from(two)
826            .filter(|two_ancestor| one_ancestors.contains(two_ancestor))
827            .map(|c| c.id)
828            .next()
829    }
830
831    pub fn find_commit(&self, id: git2::Oid) -> Option<std::rc::Rc<Commit>> {
832        self.commits.get(&id).map(|c| c.1.clone())
833    }
834
835    pub fn head_commit(&self) -> std::rc::Rc<Commit> {
836        self.commits.get(&self.head_id.unwrap()).cloned().unwrap().1
837    }
838
839    pub fn head_branch(&self) -> Option<Branch> {
840        self.branches
841            .values()
842            .find(|b| b.id == self.head_id.unwrap())
843            .cloned()
844    }
845
846    pub fn resolve(&self, revspec: &str) -> Option<std::rc::Rc<Commit>> {
847        let branch = self.branches.get(revspec)?;
848        self.find_commit(branch.id)
849    }
850
851    pub fn parent_ids(&self, head_id: git2::Oid) -> Result<Vec<git2::Oid>, git2::Error> {
852        let next = self
853            .commits
854            .get(&head_id)
855            .and_then(|(parent, _commit)| *parent);
856        Ok(next.into_iter().collect())
857    }
858
859    fn commits_from(&self, head_id: git2::Oid) -> impl Iterator<Item = std::rc::Rc<Commit>> + '_ {
860        let next = self.commits.get(&head_id).cloned();
861        CommitsFrom {
862            commits: &self.commits,
863            next,
864        }
865    }
866
867    pub fn commit_count(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize> {
868        let merge_base_id = self.merge_base(base_id, head_id)?;
869        let count = self
870            .commits_from(head_id)
871            .take_while(move |cur_id| cur_id.id != merge_base_id)
872            .count();
873        Some(count)
874    }
875
876    pub fn commit_range(
877        &self,
878        base_bound: std::ops::Bound<&git2::Oid>,
879        head_bound: std::ops::Bound<&git2::Oid>,
880    ) -> Result<Vec<git2::Oid>, git2::Error> {
881        let head_id = match head_bound {
882            std::ops::Bound::Included(head_id) | std::ops::Bound::Excluded(head_id) => *head_id,
883            std::ops::Bound::Unbounded => panic!("commit_range's HEAD cannot be unbounded"),
884        };
885        let skip = if matches!(head_bound, std::ops::Bound::Included(_)) {
886            0
887        } else {
888            1
889        };
890
891        let base_id = match base_bound {
892            std::ops::Bound::Included(base_id) | std::ops::Bound::Excluded(base_id) => {
893                debug_assert_eq!(self.merge_base(*base_id, head_id), Some(*base_id));
894                Some(*base_id)
895            }
896            std::ops::Bound::Unbounded => None,
897        };
898
899        let mut result = self
900            .commits_from(head_id)
901            .skip(skip)
902            .map(|commit| commit.id)
903            .take_while(|id| Some(*id) != base_id)
904            .collect::<Vec<_>>();
905        if let std::ops::Bound::Included(base_id) = base_bound {
906            result.push(*base_id);
907        }
908        Ok(result)
909    }
910
911    pub fn contains_commit(
912        &self,
913        haystack_id: git2::Oid,
914        needle_id: git2::Oid,
915    ) -> Result<bool, git2::Error> {
916        // Because we don't have the information for likeness matches, just checking for Oid
917        let mut next = Some(haystack_id);
918        while let Some(current) = next {
919            if current == needle_id {
920                return Ok(true);
921            }
922            next = self.commits.get(&current).and_then(|c| c.0);
923        }
924        Ok(false)
925    }
926
927    pub fn cherry_pick(
928        &mut self,
929        head_id: git2::Oid,
930        cherry_id: git2::Oid,
931    ) -> Result<git2::Oid, git2::Error> {
932        let cherry_commit = self.find_commit(cherry_id).ok_or_else(|| {
933            git2::Error::new(
934                git2::ErrorCode::NotFound,
935                git2::ErrorClass::Reference,
936                format!("could not find commit {cherry_id:?}"),
937            )
938        })?;
939        let mut cherry_commit = Commit::clone(&cherry_commit);
940        let new_id = self.gen_id();
941        cherry_commit.id = new_id;
942        self.commits
943            .insert(new_id, (Some(head_id), std::rc::Rc::new(cherry_commit)));
944        Ok(new_id)
945    }
946
947    pub fn squash(
948        &mut self,
949        head_id: git2::Oid,
950        into_id: git2::Oid,
951    ) -> Result<git2::Oid, git2::Error> {
952        self.commits.get(&head_id).cloned().ok_or_else(|| {
953            git2::Error::new(
954                git2::ErrorCode::NotFound,
955                git2::ErrorClass::Reference,
956                format!("could not find commit {head_id:?}"),
957            )
958        })?;
959        let (intos_parent, into_commit) = self.commits.get(&into_id).cloned().ok_or_else(|| {
960            git2::Error::new(
961                git2::ErrorCode::NotFound,
962                git2::ErrorClass::Reference,
963                format!("could not find commit {into_id:?}"),
964            )
965        })?;
966        let intos_parent = intos_parent.unwrap();
967
968        let mut squashed_commit = Commit::clone(&into_commit);
969        let new_id = self.gen_id();
970        squashed_commit.id = new_id;
971        self.commits.insert(
972            new_id,
973            (Some(intos_parent), std::rc::Rc::new(squashed_commit)),
974        );
975        Ok(new_id)
976    }
977
978    pub fn stash_push(&mut self, _message: Option<&str>) -> Result<git2::Oid, git2::Error> {
979        Err(git2::Error::new(
980            git2::ErrorCode::NotFound,
981            git2::ErrorClass::Reference,
982            "stash is unsupported",
983        ))
984    }
985
986    pub fn stash_pop(&mut self, _stash_id: git2::Oid) -> Result<(), git2::Error> {
987        Err(git2::Error::new(
988            git2::ErrorCode::NotFound,
989            git2::ErrorClass::Reference,
990            "stash is unsupported",
991        ))
992    }
993
994    pub fn branch(&mut self, name: &str, id: git2::Oid) -> Result<(), git2::Error> {
995        self.branches.insert(
996            name.to_owned(),
997            Branch {
998                remote: None,
999                name: name.to_owned(),
1000                id,
1001                push_id: None,
1002                pull_id: None,
1003            },
1004        );
1005        Ok(())
1006    }
1007
1008    pub fn delete_branch(&mut self, name: &str) -> Result<(), git2::Error> {
1009        self.branches.remove(name).map(|_| ()).ok_or_else(|| {
1010            git2::Error::new(
1011                git2::ErrorCode::NotFound,
1012                git2::ErrorClass::Reference,
1013                format!("could not remove branch {name:?}"),
1014            )
1015        })
1016    }
1017
1018    pub fn find_local_branch(&self, name: &str) -> Option<Branch> {
1019        self.branches.get(name).cloned()
1020    }
1021
1022    pub fn find_remote_branch(&self, _remote: &str, _name: &str) -> Option<Branch> {
1023        None
1024    }
1025
1026    pub fn local_branches(&self) -> impl Iterator<Item = Branch> + '_ {
1027        self.branches.values().cloned()
1028    }
1029
1030    pub fn remote_branches(&self) -> impl Iterator<Item = Branch> + '_ {
1031        None.into_iter()
1032    }
1033
1034    pub fn detach(&mut self) -> Result<(), git2::Error> {
1035        Ok(())
1036    }
1037
1038    pub fn switch(&mut self, name: &str) -> Result<(), git2::Error> {
1039        let branch = self.find_local_branch(name).ok_or_else(|| {
1040            git2::Error::new(
1041                git2::ErrorCode::NotFound,
1042                git2::ErrorClass::Reference,
1043                format!("could not find branch {name:?}"),
1044            )
1045        })?;
1046        self.head_id = Some(branch.id);
1047        Ok(())
1048    }
1049}
1050
1051impl Default for InMemoryRepo {
1052    fn default() -> Self {
1053        Self::new()
1054    }
1055}
1056
1057struct CommitsFrom<'c> {
1058    commits: &'c std::collections::HashMap<git2::Oid, (Option<git2::Oid>, std::rc::Rc<Commit>)>,
1059    next: Option<(Option<git2::Oid>, std::rc::Rc<Commit>)>,
1060}
1061
1062impl Iterator for CommitsFrom<'_> {
1063    type Item = std::rc::Rc<Commit>;
1064
1065    fn next(&mut self) -> Option<Self::Item> {
1066        let mut current = None;
1067        std::mem::swap(&mut current, &mut self.next);
1068        let current = current?;
1069        if let Some(parent_id) = current.0 {
1070            self.next = self.commits.get(&parent_id).cloned();
1071        }
1072        Some(current.1)
1073    }
1074}
1075
1076impl Repo for InMemoryRepo {
1077    fn path(&self) -> Option<&std::path::Path> {
1078        None
1079    }
1080    fn user(&self) -> Option<std::rc::Rc<str>> {
1081        self.user()
1082    }
1083
1084    fn is_dirty(&self) -> bool {
1085        self.is_dirty()
1086    }
1087
1088    fn merge_base(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid> {
1089        self.merge_base(one, two)
1090    }
1091
1092    fn find_commit(&self, id: git2::Oid) -> Option<std::rc::Rc<Commit>> {
1093        self.find_commit(id)
1094    }
1095
1096    fn head_commit(&self) -> std::rc::Rc<Commit> {
1097        self.head_commit()
1098    }
1099
1100    fn resolve(&self, revspec: &str) -> Option<std::rc::Rc<Commit>> {
1101        self.resolve(revspec)
1102    }
1103
1104    fn parent_ids(&self, head_id: git2::Oid) -> Result<Vec<git2::Oid>, git2::Error> {
1105        self.parent_ids(head_id)
1106    }
1107
1108    fn commit_count(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize> {
1109        self.commit_count(base_id, head_id)
1110    }
1111
1112    fn commit_range(
1113        &self,
1114        base_bound: std::ops::Bound<&git2::Oid>,
1115        head_bound: std::ops::Bound<&git2::Oid>,
1116    ) -> Result<Vec<git2::Oid>, git2::Error> {
1117        self.commit_range(base_bound, head_bound)
1118    }
1119
1120    fn contains_commit(
1121        &self,
1122        haystack_id: git2::Oid,
1123        needle_id: git2::Oid,
1124    ) -> Result<bool, git2::Error> {
1125        self.contains_commit(haystack_id, needle_id)
1126    }
1127
1128    fn cherry_pick(
1129        &mut self,
1130        head_id: git2::Oid,
1131        cherry_id: git2::Oid,
1132    ) -> Result<git2::Oid, git2::Error> {
1133        self.cherry_pick(head_id, cherry_id)
1134    }
1135
1136    fn squash(&mut self, head_id: git2::Oid, into_id: git2::Oid) -> Result<git2::Oid, git2::Error> {
1137        self.squash(head_id, into_id)
1138    }
1139
1140    fn head_branch(&self) -> Option<Branch> {
1141        self.head_branch()
1142    }
1143
1144    fn stash_push(&mut self, message: Option<&str>) -> Result<git2::Oid, git2::Error> {
1145        self.stash_push(message)
1146    }
1147
1148    fn stash_pop(&mut self, stash_id: git2::Oid) -> Result<(), git2::Error> {
1149        self.stash_pop(stash_id)
1150    }
1151
1152    fn branch(&mut self, name: &str, id: git2::Oid) -> Result<(), git2::Error> {
1153        self.branch(name, id)
1154    }
1155
1156    fn delete_branch(&mut self, name: &str) -> Result<(), git2::Error> {
1157        self.delete_branch(name)
1158    }
1159
1160    fn find_local_branch(&self, name: &str) -> Option<Branch> {
1161        self.find_local_branch(name)
1162    }
1163
1164    fn find_remote_branch(&self, remote: &str, name: &str) -> Option<Branch> {
1165        self.find_remote_branch(remote, name)
1166    }
1167
1168    fn local_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_> {
1169        Box::new(self.local_branches())
1170    }
1171
1172    fn remote_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_> {
1173        Box::new(self.remote_branches())
1174    }
1175
1176    fn detach(&mut self) -> Result<(), git2::Error> {
1177        self.detach()
1178    }
1179
1180    fn switch(&mut self, name: &str) -> Result<(), git2::Error> {
1181        self.switch(name)
1182    }
1183}
1184
1185pub fn stash_push(repo: &mut dyn Repo, context: &str) -> Option<git2::Oid> {
1186    let branch = repo.head_branch();
1187    let stash_msg = format!(
1188        "WIP on {} ({})",
1189        branch.as_ref().map(|b| b.name.as_str()).unwrap_or("HEAD"),
1190        context
1191    );
1192    match repo.stash_push(Some(&stash_msg)) {
1193        Ok(stash_id) => {
1194            log::info!(
1195                "Saved working directory and index state {}: {}",
1196                stash_msg,
1197                stash_id
1198            );
1199            Some(stash_id)
1200        }
1201        Err(err) => {
1202            log::debug!("Failed to stash: {}", err);
1203            None
1204        }
1205    }
1206}
1207
1208pub fn stash_pop(repo: &mut dyn Repo, stash_id: Option<git2::Oid>) {
1209    if let Some(stash_id) = stash_id {
1210        match repo.stash_pop(stash_id) {
1211            Ok(()) => {
1212                log::info!("Dropped refs/stash {}", stash_id);
1213            }
1214            Err(err) => {
1215                log::error!("Failed to pop {} from stash: {}", stash_id, err);
1216            }
1217        }
1218    }
1219}
1220
1221pub fn commit_range(
1222    repo: &dyn Repo,
1223    head_to_base: impl std::ops::RangeBounds<git2::Oid>,
1224) -> Result<Vec<git2::Oid>, git2::Error> {
1225    let head_bound = head_to_base.start_bound();
1226    let base_bound = head_to_base.end_bound();
1227    repo.commit_range(base_bound, head_bound)
1228}