git_stack/legacy/git/
repo.rs

1use bstr::ByteSlice;
2use itertools::Itertools;
3
/// Repository operations, abstracted so callers do not depend on one backend.
///
/// `GitRepo` implements this trait by forwarding every method to its inherent
/// method of the same name.
pub trait Repo {
    /// Filesystem path for the repository, if there is one.
    fn path(&self) -> Option<&std::path::Path>;
    /// Name of the current user, if available.
    fn user(&self) -> Option<std::rc::Rc<str>>;

    /// Whether the repository has pending state (see the implementation for
    /// what counts as dirty).
    fn is_dirty(&self) -> bool;
    /// Common ancestor of `one` and `two`, if any.
    fn merge_base(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid>;

    /// Looks up the commit with the given id.
    fn find_commit(&self, id: git2::Oid) -> Option<std::rc::Rc<Commit>>;
    /// Commit that HEAD currently points at.
    fn head_commit(&self) -> std::rc::Rc<Commit>;
    /// Branch that HEAD currently points at, if any.
    fn head_branch(&self) -> Option<Branch>;
    /// Resolves a revision specifier to a commit.
    fn resolve(&self, revspec: &str) -> Option<std::rc::Rc<Commit>>;
    /// Ids of the parents of `head_id`.
    fn parent_ids(&self, head_id: git2::Oid) -> Result<Vec<git2::Oid>, git2::Error>;
    /// Number of commits between `base_id` and `head_id`.
    fn commit_count(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize>;
    /// Commit ids between the two bounds, starting from the head side.
    fn commit_range(
        &self,
        base_bound: std::ops::Bound<&git2::Oid>,
        head_bound: std::ops::Bound<&git2::Oid>,
    ) -> Result<Vec<git2::Oid>, git2::Error>;
    /// Whether the history of `haystack_id` already contains `needle_id`.
    fn contains_commit(
        &self,
        haystack_id: git2::Oid,
        needle_id: git2::Oid,
    ) -> Result<bool, git2::Error>;
    /// Applies `cherry_id` on top of `head_id`, returning the new commit id.
    fn cherry_pick(
        &mut self,
        head_id: git2::Oid,
        cherry_id: git2::Oid,
    ) -> Result<git2::Oid, git2::Error>;
    /// Squashes `head_id` into `into_id`, returning the new commit id.
    fn squash(&mut self, head_id: git2::Oid, into_id: git2::Oid) -> Result<git2::Oid, git2::Error>;

    /// Saves pending changes to a stash and returns its id.
    fn stash_push(&mut self, message: Option<&str>) -> Result<git2::Oid, git2::Error>;
    /// Restores the stash identified by `stash_id`.
    fn stash_pop(&mut self, stash_id: git2::Oid) -> Result<(), git2::Error>;

    /// Points branch `name` at commit `id`.
    fn branch(&mut self, name: &str, id: git2::Oid) -> Result<(), git2::Error>;
    /// Deletes branch `name`.
    fn delete_branch(&mut self, name: &str) -> Result<(), git2::Error>;
    /// Looks up a local branch by name.
    fn find_local_branch(&self, name: &str) -> Option<Branch>;
    /// Looks up branch `name` on `remote`.
    fn find_remote_branch(&self, remote: &str, name: &str) -> Option<Branch>;
    /// Iterates over the local branches.
    fn local_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_>;
    /// Iterates over the remote branches.
    fn remote_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_>;
    /// Detaches HEAD from its branch.
    fn detach(&mut self) -> Result<(), git2::Error>;
    /// Switches HEAD to branch `name`.
    fn switch(&mut self, name: &str) -> Result<(), git2::Error>;
}
46
/// Snapshot of a branch: where it lives, what it is called and what it points at.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Branch {
    /// Remote the branch belongs to; `None` for a local branch.
    pub remote: Option<String>,
    /// Branch name, without the remote prefix.
    pub name: String,
    /// Commit the branch points at.
    pub id: git2::Oid,
    /// Commit of the corresponding branch on the push remote, when known.
    pub push_id: Option<git2::Oid>,
    /// Commit of the corresponding branch on the pull remote, when known.
    pub pull_id: Option<git2::Oid>,
}
55
56impl Branch {
57    pub fn local_name(&self) -> Option<&str> {
58        self.remote.is_none().then_some(self.name.as_str())
59    }
60}
61
62impl std::fmt::Display for Branch {
63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64        if let Some(remote) = self.remote.as_deref() {
65            write!(f, "{}/{}", remote, self.name.as_str())
66        } else {
67            write!(f, "{}", self.name.as_str())
68        }
69    }
70}
71
/// Lightweight, owned summary of a commit.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Commit {
    /// Id of the commit.
    pub id: git2::Oid,
    /// Id of the commit's tree.
    pub tree_id: git2::Oid,
    /// Summary of the commit message, kept as bytes because it need not be UTF-8.
    pub summary: bstr::BString,
    /// Commit time.
    pub time: std::time::SystemTime,
    /// Author name, when available.
    pub author: Option<std::rc::Rc<str>>,
    /// Committer name, when available.
    pub committer: Option<std::rc::Rc<str>>,
}
81
82impl Commit {
83    pub fn fixup_summary(&self) -> Option<&bstr::BStr> {
84        self.summary
85            .strip_prefix(b"fixup! ")
86            .map(ByteSlice::as_bstr)
87    }
88
89    pub fn wip_summary(&self) -> Option<&bstr::BStr> {
90        // Gitlab MRs only: b"[Draft]", b"(Draft)",
91        static WIP_PREFIXES: &[&[u8]] = &[
92            b"WIP:", b"draft:", b"Draft:", // Gitlab commits
93            b"wip ", b"WIP ", // Less formal
94        ];
95
96        if self.summary == b"WIP".as_bstr() || self.summary == b"wip".as_bstr() {
97            // Very informal
98            Some(b"".as_bstr())
99        } else {
100            WIP_PREFIXES.iter().find_map(|prefix| {
101                self.summary
102                    .strip_prefix(*prefix)
103                    .map(ByteSlice::trim)
104                    .map(ByteSlice::as_bstr)
105            })
106        }
107    }
108
109    pub fn revert_summary(&self) -> Option<&bstr::BStr> {
110        self.summary
111            .strip_prefix(b"Revert ")
112            .and_then(|s| s.strip_suffix(b"\""))
113            .map(ByteSlice::as_bstr)
114    }
115}
116
/// Wrapper around a `git2::Repository` that adds signing, remote defaults and
/// caches for repeated lookups.
pub struct GitRepo {
    /// Underlying libgit2 repository.
    repo: git2::Repository,
    /// Signer passed to cherry-pick and squash; `None` disables signing.
    sign: Option<git2_ext::ops::UserSign>,
    /// Remote used for push lookups; `origin` is used when unset.
    push_remote: Option<String>,
    /// Remote used for pull lookups; `origin` is used when unset.
    pull_remote: Option<String>,
    /// Cache of commits already loaded by `find_commit`.
    commits: std::cell::RefCell<std::collections::HashMap<git2::Oid, std::rc::Rc<Commit>>>,
    /// Pool of shared strings handed out by `intern_string`.
    interned_strings: std::cell::RefCell<std::collections::HashSet<std::rc::Rc<str>>>,
    /// Cache of merge bases, keyed by the ordered pair of commit ids.
    bases: std::cell::RefCell<std::collections::HashMap<(git2::Oid, git2::Oid), Option<git2::Oid>>>,
    /// Cache of commit counts, keyed by `(base_id, head_id)`.
    counts: std::cell::RefCell<std::collections::HashMap<(git2::Oid, git2::Oid), Option<usize>>>,
}
127
128impl GitRepo {
129    pub fn new(repo: git2::Repository) -> Self {
130        Self {
131            repo,
132            sign: None,
133            push_remote: None,
134            pull_remote: None,
135            commits: Default::default(),
136            interned_strings: Default::default(),
137            bases: Default::default(),
138            counts: Default::default(),
139        }
140    }
141
142    pub fn set_sign(&mut self, yes: bool) -> Result<(), git2::Error> {
143        if yes {
144            let config = self.repo.config()?;
145            let sign = git2_ext::ops::UserSign::from_config(&self.repo, &config)?;
146            self.sign = Some(sign);
147        } else {
148            self.sign = None;
149        }
150        Ok(())
151    }
152
153    pub fn set_push_remote(&mut self, remote: &str) {
154        self.push_remote = Some(remote.to_owned());
155    }
156
157    pub fn set_pull_remote(&mut self, remote: &str) {
158        self.pull_remote = Some(remote.to_owned());
159    }
160
161    pub fn push_remote(&self) -> &str {
162        self.push_remote.as_deref().unwrap_or("origin")
163    }
164
165    pub fn pull_remote(&self) -> &str {
166        self.pull_remote.as_deref().unwrap_or("origin")
167    }
168
    /// Borrows the underlying `git2::Repository`.
    pub fn raw(&self) -> &git2::Repository {
        &self.repo
    }
172
173    pub fn user(&self) -> Option<std::rc::Rc<str>> {
174        self.repo
175            .signature()
176            .ok()
177            .and_then(|s| s.name().map(|n| self.intern_string(n)))
178    }
179
180    pub fn is_dirty(&self) -> bool {
181        if self.repo.state() != git2::RepositoryState::Clean {
182            log::trace!("Repository status is unclean: {:?}", self.repo.state());
183            return true;
184        }
185
186        let status = self
187            .repo
188            .statuses(Some(git2::StatusOptions::new().include_ignored(false)))
189            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
190        if status.is_empty() {
191            false
192        } else {
193            log::trace!(
194                "Repository is dirty: {}",
195                status
196                    .iter()
197                    .filter_map(|s| s.path().map(|s| s.to_owned()))
198                    .join(", ")
199            );
200            true
201        }
202    }
203
204    pub fn merge_base(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid> {
205        if one == two {
206            return Some(one);
207        }
208
209        let (smaller, larger) = if one < two { (one, two) } else { (two, one) };
210        *self
211            .bases
212            .borrow_mut()
213            .entry((smaller, larger))
214            .or_insert_with(|| self.merge_base_raw(smaller, larger))
215    }
216
217    fn merge_base_raw(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid> {
218        self.repo.merge_base(one, two).ok()
219    }
220
221    pub fn find_commit(&self, id: git2::Oid) -> Option<std::rc::Rc<Commit>> {
222        let mut commits = self.commits.borrow_mut();
223        if let Some(commit) = commits.get(&id) {
224            Some(std::rc::Rc::clone(commit))
225        } else {
226            let commit = self.repo.find_commit(id).ok()?;
227            let summary: bstr::BString = commit.summary_bytes().unwrap().into();
228            let time = std::time::SystemTime::UNIX_EPOCH
229                + std::time::Duration::from_secs(commit.time().seconds().max(0) as u64);
230
231            let author = commit.author().name().map(|n| self.intern_string(n));
232            let committer = commit.author().name().map(|n| self.intern_string(n));
233            let commit = std::rc::Rc::new(Commit {
234                id: commit.id(),
235                tree_id: commit.tree_id(),
236                summary,
237                time,
238                author,
239                committer,
240            });
241            commits.insert(id, std::rc::Rc::clone(&commit));
242            Some(commit)
243        }
244    }
245
246    pub fn head_commit(&self) -> std::rc::Rc<Commit> {
247        let head_id = self
248            .repo
249            .head()
250            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"))
251            .resolve()
252            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"))
253            .target()
254            .unwrap();
255        self.find_commit(head_id).unwrap()
256    }
257
258    pub fn head_branch(&self) -> Option<Branch> {
259        let resolved = self
260            .repo
261            .head()
262            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"))
263            .resolve()
264            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
265        let name = resolved.shorthand()?;
266        let id = resolved.target()?;
267
268        let push_id = self
269            .repo
270            .find_branch(
271                &format!("{}/{}", self.push_remote(), name),
272                git2::BranchType::Remote,
273            )
274            .ok()
275            .and_then(|b| b.get().target());
276        let pull_id = self
277            .repo
278            .find_branch(
279                &format!("{}/{}", self.pull_remote(), name),
280                git2::BranchType::Remote,
281            )
282            .ok()
283            .and_then(|b| b.get().target());
284
285        Some(Branch {
286            remote: None,
287            name: name.to_owned(),
288            id,
289            push_id,
290            pull_id,
291        })
292    }
293
294    pub fn resolve(&self, revspec: &str) -> Option<std::rc::Rc<Commit>> {
295        let id = self.repo.revparse_single(revspec).ok()?.id();
296        self.find_commit(id)
297    }
298
299    pub fn parent_ids(&self, head_id: git2::Oid) -> Result<Vec<git2::Oid>, git2::Error> {
300        let commit = self.repo.find_commit(head_id)?;
301        Ok(commit.parent_ids().collect())
302    }
303
304    pub fn commit_count(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize> {
305        if base_id == head_id {
306            return Some(0);
307        }
308
309        *self
310            .counts
311            .borrow_mut()
312            .entry((base_id, head_id))
313            .or_insert_with(|| self.commit_count_raw(base_id, head_id))
314    }
315
316    fn commit_count_raw(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize> {
317        let merge_base_id = self.merge_base(base_id, head_id)?;
318        if merge_base_id != base_id {
319            return None;
320        }
321        let mut revwalk = self
322            .repo
323            .revwalk()
324            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
325        revwalk
326            .push(head_id)
327            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
328        revwalk
329            .hide(base_id)
330            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
331        Some(revwalk.count())
332    }
333
334    pub fn commit_range(
335        &self,
336        base_bound: std::ops::Bound<&git2::Oid>,
337        head_bound: std::ops::Bound<&git2::Oid>,
338    ) -> Result<Vec<git2::Oid>, git2::Error> {
339        let head_id = match head_bound {
340            std::ops::Bound::Included(head_id) | std::ops::Bound::Excluded(head_id) => *head_id,
341            std::ops::Bound::Unbounded => panic!("commit_range's HEAD cannot be unbounded"),
342        };
343        let skip = if matches!(head_bound, std::ops::Bound::Included(_)) {
344            0
345        } else {
346            1
347        };
348
349        let base_id = match base_bound {
350            std::ops::Bound::Included(base_id) | std::ops::Bound::Excluded(base_id) => {
351                debug_assert_eq!(self.merge_base(*base_id, head_id), Some(*base_id));
352                Some(*base_id)
353            }
354            std::ops::Bound::Unbounded => None,
355        };
356
357        let mut revwalk = self.repo.revwalk()?;
358        revwalk.push(head_id)?;
359        if let Some(base_id) = base_id {
360            revwalk.hide(base_id)?;
361        }
362        revwalk.set_sorting(git2::Sort::TOPOLOGICAL)?;
363        let mut result = revwalk
364            .filter_map(Result::ok)
365            .skip(skip)
366            .take_while(|id| Some(*id) != base_id)
367            .collect::<Vec<_>>();
368        if let std::ops::Bound::Included(base_id) = base_bound {
369            result.push(*base_id);
370        }
371        Ok(result)
372    }
373
    /// Checks whether the change introduced by `needle_id` is already present
    /// in `haystack_id`.
    ///
    /// This runs an in-memory rebase of `needle_id` (relative to its first
    /// parent, when it has one) onto `haystack_id`:
    /// - a conflict, or a successfully created commit, yields `Ok(false)`;
    /// - a rebase "already applied" error, or a rebase with no operations,
    ///   yields `Ok(true)`.
    ///
    /// # Panics
    /// Panics if the in-memory index or the default signature is unavailable.
    ///
    /// # Errors
    /// Returns any other libgit2 error raised along the way.
    pub fn contains_commit(
        &self,
        haystack_id: git2::Oid,
        needle_id: git2::Oid,
    ) -> Result<bool, git2::Error> {
        let needle_commit = self.repo.find_commit(needle_id)?;
        let needle_ann_commit = self.repo.find_annotated_commit(needle_id)?;
        let haystack_ann_commit = self.repo.find_annotated_commit(haystack_id)?;

        // Only the first parent is used as the upstream of the rebase.
        let parent_ann_commit = if 0 < needle_commit.parent_count() {
            let parent_commit = needle_commit.parent(0)?;
            Some(self.repo.find_annotated_commit(parent_commit.id())?)
        } else {
            None
        };

        let mut rebase = self.repo.rebase(
            Some(&needle_ann_commit),
            parent_ann_commit.as_ref(),
            Some(&haystack_ann_commit),
            Some(git2::RebaseOptions::new().inmemory(true)),
        )?;

        if let Some(op) = rebase.next() {
            // Best-effort abort before propagating the failure.
            op.map_err(|e| {
                let _ = rebase.abort();
                e
            })?;
            let inmemory_index = rebase
                .inmemory_index()
                .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
            if inmemory_index.has_conflicts() {
                return Ok(false);
            }

            let sig = self
                .repo
                .signature()
                .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
            let result = rebase.commit(None, &sig, None).map_err(|e| {
                let _ = rebase.abort();
                e
            });
            match result {
                // Created commit, must be unique
                Ok(_) => Ok(false),
                Err(err) => {
                    // "Already applied" means the haystack has this change.
                    if err.class() == git2::ErrorClass::Rebase
                        && err.code() == git2::ErrorCode::Applied
                    {
                        return Ok(true);
                    }
                    Err(err)
                }
            }
        } else {
            // No commit created, must exist somehow
            rebase.finish(None)?;
            Ok(true)
        }
    }
435
436    fn cherry_pick(
437        &mut self,
438        head_id: git2::Oid,
439        cherry_id: git2::Oid,
440    ) -> Result<git2::Oid, git2::Error> {
441        git2_ext::ops::cherry_pick(
442            &self.repo,
443            head_id,
444            cherry_id,
445            self.sign.as_ref().map(|s| s as &dyn git2_ext::ops::Sign),
446        )
447    }
448
449    pub fn squash(
450        &mut self,
451        head_id: git2::Oid,
452        into_id: git2::Oid,
453    ) -> Result<git2::Oid, git2::Error> {
454        git2_ext::ops::squash(
455            &self.repo,
456            head_id,
457            into_id,
458            self.sign.as_ref().map(|s| s as &dyn git2_ext::ops::Sign),
459        )
460    }
461
462    pub fn stash_push(&mut self, message: Option<&str>) -> Result<git2::Oid, git2::Error> {
463        let signature = self.repo.signature()?;
464        self.repo.stash_save2(&signature, message, None)
465    }
466
467    pub fn stash_pop(&mut self, stash_id: git2::Oid) -> Result<(), git2::Error> {
468        let mut index = None;
469        self.repo.stash_foreach(|i, _, id| {
470            if *id == stash_id {
471                index = Some(i);
472                false
473            } else {
474                true
475            }
476        })?;
477        let index = index.ok_or_else(|| {
478            git2::Error::new(
479                git2::ErrorCode::NotFound,
480                git2::ErrorClass::Reference,
481                "stash ID not found",
482            )
483        })?;
484        self.repo.stash_pop(index, None)
485    }
486
487    pub fn branch(&mut self, name: &str, id: git2::Oid) -> Result<(), git2::Error> {
488        let commit = self.repo.find_commit(id)?;
489        self.repo.branch(name, &commit, true)?;
490        Ok(())
491    }
492
493    pub fn delete_branch(&mut self, name: &str) -> Result<(), git2::Error> {
494        // HACK: We shouldn't limit ourselves to `Local`
495        let mut branch = self.repo.find_branch(name, git2::BranchType::Local)?;
496        branch.delete()
497    }
498
499    pub fn find_local_branch(&self, name: &str) -> Option<Branch> {
500        let branch = self.repo.find_branch(name, git2::BranchType::Local).ok()?;
501        self.load_local_branch(&branch, name).ok()
502    }
503
504    pub fn find_remote_branch(&self, remote: &str, name: &str) -> Option<Branch> {
505        let qualified = format!("{remote}/{name}");
506        let branch = self
507            .repo
508            .find_branch(&qualified, git2::BranchType::Remote)
509            .ok()?;
510        self.load_remote_branch(&branch, remote, name).ok()
511    }
512
513    pub fn local_branches(&self) -> impl Iterator<Item = Branch> + '_ {
514        log::trace!("Loading local branches");
515        self.repo
516            .branches(Some(git2::BranchType::Local))
517            .into_iter()
518            .flatten()
519            .filter_map(move |branch| {
520                let (branch, _) = branch.ok()?;
521                let name = if let Some(name) = branch.name().ok().flatten() {
522                    name
523                } else {
524                    log::debug!(
525                        "Ignoring non-UTF8 branch {:?}",
526                        branch.name_bytes().unwrap().as_bstr()
527                    );
528                    return None;
529                };
530                self.load_local_branch(&branch, name).ok()
531            })
532    }
533
534    pub fn remote_branches(&self) -> impl Iterator<Item = Branch> + '_ {
535        log::trace!("Loading remote branches");
536        self.repo
537            .branches(Some(git2::BranchType::Remote))
538            .into_iter()
539            .flatten()
540            .filter_map(move |branch| {
541                let (branch, _) = branch.ok()?;
542                let name = if let Some(name) = branch.name().ok().flatten() {
543                    name
544                } else {
545                    log::debug!(
546                        "Ignoring non-UTF8 branch {:?}",
547                        branch.name_bytes().unwrap().as_bstr()
548                    );
549                    return None;
550                };
551                let (remote, name) = name.split_once('/').unwrap();
552                self.load_remote_branch(&branch, remote, name).ok()
553            })
554    }
555
556    fn load_local_branch(
557        &self,
558        branch: &git2::Branch<'_>,
559        name: &str,
560    ) -> Result<Branch, git2::Error> {
561        let id = branch.get().target().unwrap();
562
563        let push_id = self
564            .repo
565            .find_branch(
566                &format!("{}/{}", self.push_remote(), name),
567                git2::BranchType::Remote,
568            )
569            .ok()
570            .and_then(|b| b.get().target());
571        let pull_id = self
572            .repo
573            .find_branch(
574                &format!("{}/{}", self.pull_remote(), name),
575                git2::BranchType::Remote,
576            )
577            .ok()
578            .and_then(|b| b.get().target());
579
580        Ok(Branch {
581            remote: None,
582            name: name.to_owned(),
583            id,
584            push_id,
585            pull_id,
586        })
587    }
588
589    fn load_remote_branch(
590        &self,
591        branch: &git2::Branch<'_>,
592        remote: &str,
593        name: &str,
594    ) -> Result<Branch, git2::Error> {
595        let id = branch.get().target().unwrap();
596
597        let push_id = (remote == self.push_remote()).then_some(id);
598        let pull_id = (remote == self.pull_remote()).then_some(id);
599
600        Ok(Branch {
601            remote: Some(remote.to_owned()),
602            name: name.to_owned(),
603            id,
604            push_id,
605            pull_id,
606        })
607    }
608
609    pub fn detach(&mut self) -> Result<(), git2::Error> {
610        let head_id = self
611            .repo
612            .head()
613            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"))
614            .resolve()
615            .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"))
616            .target()
617            .unwrap();
618        self.repo.set_head_detached(head_id)?;
619        Ok(())
620    }
621
622    pub fn switch(&mut self, name: &str) -> Result<(), git2::Error> {
623        // HACK: We shouldn't limit ourselves to `Local`
624        let branch = self.repo.find_branch(name, git2::BranchType::Local)?;
625        self.repo.set_head(branch.get().name().unwrap())?;
626        let mut builder = git2::build::CheckoutBuilder::new();
627        builder.force();
628        self.repo.checkout_head(Some(&mut builder))?;
629        Ok(())
630    }
631
632    fn intern_string(&self, data: &str) -> std::rc::Rc<str> {
633        let mut interned_strings = self.interned_strings.borrow_mut();
634        if let Some(interned) = interned_strings.get(data) {
635            std::rc::Rc::clone(interned)
636        } else {
637            let interned = std::rc::Rc::from(data);
638            interned_strings.insert(std::rc::Rc::clone(&interned));
639            interned
640        }
641    }
642}
643
644impl std::fmt::Debug for GitRepo {
645    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
646        f.debug_struct("GitRepo")
647            .field("repo", &self.repo.workdir())
648            .field("push_remote", &self.push_remote.as_deref())
649            .field("pull_remote", &self.pull_remote.as_deref())
650            .finish()
651    }
652}
653
654impl Repo for GitRepo {
655    fn path(&self) -> Option<&std::path::Path> {
656        Some(self.repo.path())
657    }
658    fn user(&self) -> Option<std::rc::Rc<str>> {
659        self.user()
660    }
661
662    fn is_dirty(&self) -> bool {
663        self.is_dirty()
664    }
665
666    fn merge_base(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid> {
667        self.merge_base(one, two)
668    }
669
670    fn find_commit(&self, id: git2::Oid) -> Option<std::rc::Rc<Commit>> {
671        self.find_commit(id)
672    }
673
674    fn head_commit(&self) -> std::rc::Rc<Commit> {
675        self.head_commit()
676    }
677
678    fn head_branch(&self) -> Option<Branch> {
679        self.head_branch()
680    }
681
682    fn resolve(&self, revspec: &str) -> Option<std::rc::Rc<Commit>> {
683        self.resolve(revspec)
684    }
685
686    fn parent_ids(&self, head_id: git2::Oid) -> Result<Vec<git2::Oid>, git2::Error> {
687        self.parent_ids(head_id)
688    }
689
690    fn commit_count(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize> {
691        self.commit_count(base_id, head_id)
692    }
693
694    fn commit_range(
695        &self,
696        base_bound: std::ops::Bound<&git2::Oid>,
697        head_bound: std::ops::Bound<&git2::Oid>,
698    ) -> Result<Vec<git2::Oid>, git2::Error> {
699        self.commit_range(base_bound, head_bound)
700    }
701
702    fn contains_commit(
703        &self,
704        haystack_id: git2::Oid,
705        needle_id: git2::Oid,
706    ) -> Result<bool, git2::Error> {
707        self.contains_commit(haystack_id, needle_id)
708    }
709
710    fn cherry_pick(
711        &mut self,
712        head_id: git2::Oid,
713        cherry_id: git2::Oid,
714    ) -> Result<git2::Oid, git2::Error> {
715        self.cherry_pick(head_id, cherry_id)
716    }
717
718    fn squash(&mut self, head_id: git2::Oid, into_id: git2::Oid) -> Result<git2::Oid, git2::Error> {
719        self.squash(head_id, into_id)
720    }
721
722    fn stash_push(&mut self, message: Option<&str>) -> Result<git2::Oid, git2::Error> {
723        self.stash_push(message)
724    }
725
726    fn stash_pop(&mut self, stash_id: git2::Oid) -> Result<(), git2::Error> {
727        self.stash_pop(stash_id)
728    }
729
730    fn branch(&mut self, name: &str, id: git2::Oid) -> Result<(), git2::Error> {
731        self.branch(name, id)
732    }
733
734    fn delete_branch(&mut self, name: &str) -> Result<(), git2::Error> {
735        self.delete_branch(name)
736    }
737
738    fn find_local_branch(&self, name: &str) -> Option<Branch> {
739        self.find_local_branch(name)
740    }
741
742    fn find_remote_branch(&self, remote: &str, name: &str) -> Option<Branch> {
743        self.find_remote_branch(remote, name)
744    }
745
746    fn local_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_> {
747        Box::new(self.local_branches())
748    }
749
750    fn remote_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_> {
751        Box::new(self.remote_branches())
752    }
753
754    fn detach(&mut self) -> Result<(), git2::Error> {
755        self.detach()
756    }
757
758    fn switch(&mut self, name: &str) -> Result<(), git2::Error> {
759        self.switch(name)
760    }
761}
762
/// Repository kept entirely in memory, with a single-parent commit history.
#[derive(Debug)]
pub struct InMemoryRepo {
    /// Commit id mapped to its optional parent id and the commit itself.
    commits: std::collections::HashMap<git2::Oid, (Option<git2::Oid>, std::rc::Rc<Commit>)>,
    /// Branches keyed by branch name.
    branches: std::collections::HashMap<String, Branch>,
    /// Current head commit, if one has been set.
    head_id: Option<git2::Oid>,

    /// Counter that `gen_id` turns into the next synthetic commit id.
    last_id: std::sync::atomic::AtomicUsize,
}
771
772impl InMemoryRepo {
773    pub fn new() -> Self {
774        Self {
775            commits: Default::default(),
776            branches: Default::default(),
777            head_id: Default::default(),
778            last_id: std::sync::atomic::AtomicUsize::new(1),
779        }
780    }
781
782    pub fn clear(&mut self) {
783        *self = InMemoryRepo::new();
784    }
785
786    pub fn gen_id(&mut self) -> git2::Oid {
787        let last_id = self
788            .last_id
789            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
790        let sha = format!("{last_id:040x}");
791        git2::Oid::from_str(&sha).unwrap()
792    }
793
794    pub fn push_commit(&mut self, parent_id: Option<git2::Oid>, commit: Commit) {
795        if let Some(parent_id) = parent_id {
796            assert!(self.commits.contains_key(&parent_id));
797        }
798        self.head_id = Some(commit.id);
799        self.commits
800            .insert(commit.id, (parent_id, std::rc::Rc::new(commit)));
801    }
802
803    pub fn head_id(&mut self) -> Option<git2::Oid> {
804        self.head_id
805    }
806
    /// Moves the head to `head_id`.
    ///
    /// # Panics
    /// Panics if `head_id` is not a known commit.
    pub fn set_head(&mut self, head_id: git2::Oid) {
        assert!(self.commits.contains_key(&head_id));
        self.head_id = Some(head_id);
    }
811
812    pub fn mark_branch(&mut self, branch: Branch) {
813        assert!(self.commits.contains_key(&branch.id));
814        self.branches.insert(branch.name.clone(), branch);
815    }
816
817    fn user(&self) -> Option<std::rc::Rc<str>> {
818        None
819    }
820
    /// Always reports the repository as clean.
    pub fn is_dirty(&self) -> bool {
        false
    }
824
825    pub fn merge_base(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid> {
826        let one_ancestors: Vec<_> = self.commits_from(one).collect();
827        self.commits_from(two)
828            .filter(|two_ancestor| one_ancestors.contains(two_ancestor))
829            .map(|c| c.id)
830            .next()
831    }
832
833    pub fn find_commit(&self, id: git2::Oid) -> Option<std::rc::Rc<Commit>> {
834        self.commits.get(&id).map(|c| c.1.clone())
835    }
836
837    pub fn head_commit(&self) -> std::rc::Rc<Commit> {
838        self.commits.get(&self.head_id.unwrap()).cloned().unwrap().1
839    }
840
841    pub fn head_branch(&self) -> Option<Branch> {
842        self.branches
843            .values()
844            .find(|b| b.id == self.head_id.unwrap())
845            .cloned()
846    }
847
848    pub fn resolve(&self, revspec: &str) -> Option<std::rc::Rc<Commit>> {
849        let branch = self.branches.get(revspec)?;
850        self.find_commit(branch.id)
851    }
852
853    pub fn parent_ids(&self, head_id: git2::Oid) -> Result<Vec<git2::Oid>, git2::Error> {
854        let next = self
855            .commits
856            .get(&head_id)
857            .and_then(|(parent, _commit)| *parent);
858        Ok(next.into_iter().collect())
859    }
860
861    fn commits_from(&self, head_id: git2::Oid) -> impl Iterator<Item = std::rc::Rc<Commit>> + '_ {
862        let next = self.commits.get(&head_id).cloned();
863        CommitsFrom {
864            commits: &self.commits,
865            next,
866        }
867    }
868
869    pub fn commit_count(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize> {
870        let merge_base_id = self.merge_base(base_id, head_id)?;
871        let count = self
872            .commits_from(head_id)
873            .take_while(move |cur_id| cur_id.id != merge_base_id)
874            .count();
875        Some(count)
876    }
877
878    pub fn commit_range(
879        &self,
880        base_bound: std::ops::Bound<&git2::Oid>,
881        head_bound: std::ops::Bound<&git2::Oid>,
882    ) -> Result<Vec<git2::Oid>, git2::Error> {
883        let head_id = match head_bound {
884            std::ops::Bound::Included(head_id) | std::ops::Bound::Excluded(head_id) => *head_id,
885            std::ops::Bound::Unbounded => panic!("commit_range's HEAD cannot be unbounded"),
886        };
887        let skip = if matches!(head_bound, std::ops::Bound::Included(_)) {
888            0
889        } else {
890            1
891        };
892
893        let base_id = match base_bound {
894            std::ops::Bound::Included(base_id) | std::ops::Bound::Excluded(base_id) => {
895                debug_assert_eq!(self.merge_base(*base_id, head_id), Some(*base_id));
896                Some(*base_id)
897            }
898            std::ops::Bound::Unbounded => None,
899        };
900
901        let mut result = self
902            .commits_from(head_id)
903            .skip(skip)
904            .map(|commit| commit.id)
905            .take_while(|id| Some(*id) != base_id)
906            .collect::<Vec<_>>();
907        if let std::ops::Bound::Included(base_id) = base_bound {
908            result.push(*base_id);
909        }
910        Ok(result)
911    }
912
913    pub fn contains_commit(
914        &self,
915        haystack_id: git2::Oid,
916        needle_id: git2::Oid,
917    ) -> Result<bool, git2::Error> {
918        // Because we don't have the information for likeness matches, just checking for Oid
919        let mut next = Some(haystack_id);
920        while let Some(current) = next {
921            if current == needle_id {
922                return Ok(true);
923            }
924            next = self.commits.get(&current).and_then(|c| c.0);
925        }
926        Ok(false)
927    }
928
929    pub fn cherry_pick(
930        &mut self,
931        head_id: git2::Oid,
932        cherry_id: git2::Oid,
933    ) -> Result<git2::Oid, git2::Error> {
934        let cherry_commit = self.find_commit(cherry_id).ok_or_else(|| {
935            git2::Error::new(
936                git2::ErrorCode::NotFound,
937                git2::ErrorClass::Reference,
938                format!("could not find commit {cherry_id:?}"),
939            )
940        })?;
941        let mut cherry_commit = Commit::clone(&cherry_commit);
942        let new_id = self.gen_id();
943        cherry_commit.id = new_id;
944        self.commits
945            .insert(new_id, (Some(head_id), std::rc::Rc::new(cherry_commit)));
946        Ok(new_id)
947    }
948
949    pub fn squash(
950        &mut self,
951        head_id: git2::Oid,
952        into_id: git2::Oid,
953    ) -> Result<git2::Oid, git2::Error> {
954        self.commits.get(&head_id).cloned().ok_or_else(|| {
955            git2::Error::new(
956                git2::ErrorCode::NotFound,
957                git2::ErrorClass::Reference,
958                format!("could not find commit {head_id:?}"),
959            )
960        })?;
961        let (intos_parent, into_commit) = self.commits.get(&into_id).cloned().ok_or_else(|| {
962            git2::Error::new(
963                git2::ErrorCode::NotFound,
964                git2::ErrorClass::Reference,
965                format!("could not find commit {into_id:?}"),
966            )
967        })?;
968        let intos_parent = intos_parent.unwrap();
969
970        let mut squashed_commit = Commit::clone(&into_commit);
971        let new_id = self.gen_id();
972        squashed_commit.id = new_id;
973        self.commits.insert(
974            new_id,
975            (Some(intos_parent), std::rc::Rc::new(squashed_commit)),
976        );
977        Ok(new_id)
978    }
979
980    pub fn stash_push(&mut self, _message: Option<&str>) -> Result<git2::Oid, git2::Error> {
981        Err(git2::Error::new(
982            git2::ErrorCode::NotFound,
983            git2::ErrorClass::Reference,
984            "stash is unsupported",
985        ))
986    }
987
988    pub fn stash_pop(&mut self, _stash_id: git2::Oid) -> Result<(), git2::Error> {
989        Err(git2::Error::new(
990            git2::ErrorCode::NotFound,
991            git2::ErrorClass::Reference,
992            "stash is unsupported",
993        ))
994    }
995
996    pub fn branch(&mut self, name: &str, id: git2::Oid) -> Result<(), git2::Error> {
997        self.branches.insert(
998            name.to_owned(),
999            Branch {
1000                remote: None,
1001                name: name.to_owned(),
1002                id,
1003                push_id: None,
1004                pull_id: None,
1005            },
1006        );
1007        Ok(())
1008    }
1009
1010    pub fn delete_branch(&mut self, name: &str) -> Result<(), git2::Error> {
1011        self.branches.remove(name).map(|_| ()).ok_or_else(|| {
1012            git2::Error::new(
1013                git2::ErrorCode::NotFound,
1014                git2::ErrorClass::Reference,
1015                format!("could not remove branch {name:?}"),
1016            )
1017        })
1018    }
1019
1020    pub fn find_local_branch(&self, name: &str) -> Option<Branch> {
1021        self.branches.get(name).cloned()
1022    }
1023
1024    pub fn find_remote_branch(&self, _remote: &str, _name: &str) -> Option<Branch> {
1025        None
1026    }
1027
1028    pub fn local_branches(&self) -> impl Iterator<Item = Branch> + '_ {
1029        self.branches.values().cloned()
1030    }
1031
1032    pub fn remote_branches(&self) -> impl Iterator<Item = Branch> + '_ {
1033        None.into_iter()
1034    }
1035
1036    pub fn detach(&mut self) -> Result<(), git2::Error> {
1037        Ok(())
1038    }
1039
1040    pub fn switch(&mut self, name: &str) -> Result<(), git2::Error> {
1041        let branch = self.find_local_branch(name).ok_or_else(|| {
1042            git2::Error::new(
1043                git2::ErrorCode::NotFound,
1044                git2::ErrorClass::Reference,
1045                format!("could not find branch {name:?}"),
1046            )
1047        })?;
1048        self.head_id = Some(branch.id);
1049        Ok(())
1050    }
1051}
1052
1053impl Default for InMemoryRepo {
1054    fn default() -> Self {
1055        Self::new()
1056    }
1057}
1058
1059struct CommitsFrom<'c> {
1060    commits: &'c std::collections::HashMap<git2::Oid, (Option<git2::Oid>, std::rc::Rc<Commit>)>,
1061    next: Option<(Option<git2::Oid>, std::rc::Rc<Commit>)>,
1062}
1063
1064impl Iterator for CommitsFrom<'_> {
1065    type Item = std::rc::Rc<Commit>;
1066
1067    fn next(&mut self) -> Option<Self::Item> {
1068        let mut current = None;
1069        std::mem::swap(&mut current, &mut self.next);
1070        let current = current?;
1071        if let Some(parent_id) = current.0 {
1072            self.next = self.commits.get(&parent_id).cloned();
1073        }
1074        Some(current.1)
1075    }
1076}
1077
1078impl Repo for InMemoryRepo {
1079    fn path(&self) -> Option<&std::path::Path> {
1080        None
1081    }
1082    fn user(&self) -> Option<std::rc::Rc<str>> {
1083        self.user()
1084    }
1085
1086    fn is_dirty(&self) -> bool {
1087        self.is_dirty()
1088    }
1089
1090    fn merge_base(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid> {
1091        self.merge_base(one, two)
1092    }
1093
1094    fn find_commit(&self, id: git2::Oid) -> Option<std::rc::Rc<Commit>> {
1095        self.find_commit(id)
1096    }
1097
1098    fn head_commit(&self) -> std::rc::Rc<Commit> {
1099        self.head_commit()
1100    }
1101
1102    fn resolve(&self, revspec: &str) -> Option<std::rc::Rc<Commit>> {
1103        self.resolve(revspec)
1104    }
1105
1106    fn parent_ids(&self, head_id: git2::Oid) -> Result<Vec<git2::Oid>, git2::Error> {
1107        self.parent_ids(head_id)
1108    }
1109
1110    fn commit_count(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize> {
1111        self.commit_count(base_id, head_id)
1112    }
1113
1114    fn commit_range(
1115        &self,
1116        base_bound: std::ops::Bound<&git2::Oid>,
1117        head_bound: std::ops::Bound<&git2::Oid>,
1118    ) -> Result<Vec<git2::Oid>, git2::Error> {
1119        self.commit_range(base_bound, head_bound)
1120    }
1121
1122    fn contains_commit(
1123        &self,
1124        haystack_id: git2::Oid,
1125        needle_id: git2::Oid,
1126    ) -> Result<bool, git2::Error> {
1127        self.contains_commit(haystack_id, needle_id)
1128    }
1129
1130    fn cherry_pick(
1131        &mut self,
1132        head_id: git2::Oid,
1133        cherry_id: git2::Oid,
1134    ) -> Result<git2::Oid, git2::Error> {
1135        self.cherry_pick(head_id, cherry_id)
1136    }
1137
1138    fn squash(&mut self, head_id: git2::Oid, into_id: git2::Oid) -> Result<git2::Oid, git2::Error> {
1139        self.squash(head_id, into_id)
1140    }
1141
1142    fn head_branch(&self) -> Option<Branch> {
1143        self.head_branch()
1144    }
1145
1146    fn stash_push(&mut self, message: Option<&str>) -> Result<git2::Oid, git2::Error> {
1147        self.stash_push(message)
1148    }
1149
1150    fn stash_pop(&mut self, stash_id: git2::Oid) -> Result<(), git2::Error> {
1151        self.stash_pop(stash_id)
1152    }
1153
1154    fn branch(&mut self, name: &str, id: git2::Oid) -> Result<(), git2::Error> {
1155        self.branch(name, id)
1156    }
1157
1158    fn delete_branch(&mut self, name: &str) -> Result<(), git2::Error> {
1159        self.delete_branch(name)
1160    }
1161
1162    fn find_local_branch(&self, name: &str) -> Option<Branch> {
1163        self.find_local_branch(name)
1164    }
1165
1166    fn find_remote_branch(&self, remote: &str, name: &str) -> Option<Branch> {
1167        self.find_remote_branch(remote, name)
1168    }
1169
1170    fn local_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_> {
1171        Box::new(self.local_branches())
1172    }
1173
1174    fn remote_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_> {
1175        Box::new(self.remote_branches())
1176    }
1177
1178    fn detach(&mut self) -> Result<(), git2::Error> {
1179        self.detach()
1180    }
1181
1182    fn switch(&mut self, name: &str) -> Result<(), git2::Error> {
1183        self.switch(name)
1184    }
1185}
1186
1187pub fn stash_push(repo: &mut dyn Repo, context: &str) -> Option<git2::Oid> {
1188    let branch = repo.head_branch();
1189    let stash_msg = format!(
1190        "WIP on {} ({})",
1191        branch.as_ref().map(|b| b.name.as_str()).unwrap_or("HEAD"),
1192        context
1193    );
1194    match repo.stash_push(Some(&stash_msg)) {
1195        Ok(stash_id) => {
1196            log::info!(
1197                "Saved working directory and index state {}: {}",
1198                stash_msg,
1199                stash_id
1200            );
1201            Some(stash_id)
1202        }
1203        Err(err) => {
1204            log::debug!("Failed to stash: {}", err);
1205            None
1206        }
1207    }
1208}
1209
1210pub fn stash_pop(repo: &mut dyn Repo, stash_id: Option<git2::Oid>) {
1211    if let Some(stash_id) = stash_id {
1212        match repo.stash_pop(stash_id) {
1213            Ok(()) => {
1214                log::info!("Dropped refs/stash {}", stash_id);
1215            }
1216            Err(err) => {
1217                log::error!("Failed to pop {} from stash: {}", stash_id, err);
1218            }
1219        }
1220    }
1221}
1222
1223pub fn commit_range(
1224    repo: &dyn Repo,
1225    head_to_base: impl std::ops::RangeBounds<git2::Oid>,
1226) -> Result<Vec<git2::Oid>, git2::Error> {
1227    let head_bound = head_to_base.start_bound();
1228    let base_bound = head_to_base.end_bound();
1229    repo.commit_range(base_bound, head_bound)
1230}