1use bstr::ByteSlice;
2use itertools::Itertools;
3
4pub trait Repo {
5 fn path(&self) -> Option<&std::path::Path>;
6 fn user(&self) -> Option<std::rc::Rc<str>>;
7
8 fn is_dirty(&self) -> bool;
9 fn merge_base(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid>;
10
11 fn find_commit(&self, id: git2::Oid) -> Option<std::rc::Rc<Commit>>;
12 fn head_commit(&self) -> std::rc::Rc<Commit>;
13 fn head_branch(&self) -> Option<Branch>;
14 fn resolve(&self, revspec: &str) -> Option<std::rc::Rc<Commit>>;
15 fn parent_ids(&self, head_id: git2::Oid) -> Result<Vec<git2::Oid>, git2::Error>;
16 fn commit_count(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize>;
17 fn commit_range(
18 &self,
19 base_bound: std::ops::Bound<&git2::Oid>,
20 head_bound: std::ops::Bound<&git2::Oid>,
21 ) -> Result<Vec<git2::Oid>, git2::Error>;
22 fn contains_commit(
23 &self,
24 haystack_id: git2::Oid,
25 needle_id: git2::Oid,
26 ) -> Result<bool, git2::Error>;
27 fn cherry_pick(
28 &mut self,
29 head_id: git2::Oid,
30 cherry_id: git2::Oid,
31 ) -> Result<git2::Oid, git2::Error>;
32 fn squash(&mut self, head_id: git2::Oid, into_id: git2::Oid) -> Result<git2::Oid, git2::Error>;
33
34 fn stash_push(&mut self, message: Option<&str>) -> Result<git2::Oid, git2::Error>;
35 fn stash_pop(&mut self, stash_id: git2::Oid) -> Result<(), git2::Error>;
36
37 fn branch(&mut self, name: &str, id: git2::Oid) -> Result<(), git2::Error>;
38 fn delete_branch(&mut self, name: &str) -> Result<(), git2::Error>;
39 fn find_local_branch(&self, name: &str) -> Option<Branch>;
40 fn find_remote_branch(&self, remote: &str, name: &str) -> Option<Branch>;
41 fn local_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_>;
42 fn remote_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_>;
43 fn detach(&mut self) -> Result<(), git2::Error>;
44 fn switch(&mut self, name: &str) -> Result<(), git2::Error>;
45}
46
47#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
48pub struct Branch {
49 pub remote: Option<String>,
50 pub name: String,
51 pub id: git2::Oid,
52 pub push_id: Option<git2::Oid>,
53 pub pull_id: Option<git2::Oid>,
54}
55
56impl Branch {
57 pub fn local_name(&self) -> Option<&str> {
58 self.remote.is_none().then_some(self.name.as_str())
59 }
60}
61
62impl std::fmt::Display for Branch {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 if let Some(remote) = self.remote.as_deref() {
65 write!(f, "{}/{}", remote, self.name.as_str())
66 } else {
67 write!(f, "{}", self.name.as_str())
68 }
69 }
70}
71
72#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
73pub struct Commit {
74 pub id: git2::Oid,
75 pub tree_id: git2::Oid,
76 pub summary: bstr::BString,
77 pub time: std::time::SystemTime,
78 pub author: Option<std::rc::Rc<str>>,
79 pub committer: Option<std::rc::Rc<str>>,
80}
81
82impl Commit {
83 pub fn fixup_summary(&self) -> Option<&bstr::BStr> {
84 self.summary
85 .strip_prefix(b"fixup! ")
86 .map(ByteSlice::as_bstr)
87 }
88
89 pub fn wip_summary(&self) -> Option<&bstr::BStr> {
90 static WIP_PREFIXES: &[&[u8]] = &[
92 b"WIP:", b"draft:", b"Draft:", b"wip ", b"WIP ", ];
95
96 if self.summary == b"WIP".as_bstr() || self.summary == b"wip".as_bstr() {
97 Some(b"".as_bstr())
99 } else {
100 WIP_PREFIXES.iter().find_map(|prefix| {
101 self.summary
102 .strip_prefix(*prefix)
103 .map(ByteSlice::trim)
104 .map(ByteSlice::as_bstr)
105 })
106 }
107 }
108
109 pub fn revert_summary(&self) -> Option<&bstr::BStr> {
110 self.summary
111 .strip_prefix(b"Revert ")
112 .and_then(|s| s.strip_suffix(b"\""))
113 .map(ByteSlice::as_bstr)
114 }
115}
116
117pub struct GitRepo {
118 repo: git2::Repository,
119 sign: Option<git2_ext::ops::UserSign>,
120 push_remote: Option<String>,
121 pull_remote: Option<String>,
122 commits: std::cell::RefCell<std::collections::HashMap<git2::Oid, std::rc::Rc<Commit>>>,
123 interned_strings: std::cell::RefCell<std::collections::HashSet<std::rc::Rc<str>>>,
124 bases: std::cell::RefCell<std::collections::HashMap<(git2::Oid, git2::Oid), Option<git2::Oid>>>,
125 counts: std::cell::RefCell<std::collections::HashMap<(git2::Oid, git2::Oid), Option<usize>>>,
126}
127
128impl GitRepo {
129 pub fn new(repo: git2::Repository) -> Self {
130 Self {
131 repo,
132 sign: None,
133 push_remote: None,
134 pull_remote: None,
135 commits: Default::default(),
136 interned_strings: Default::default(),
137 bases: Default::default(),
138 counts: Default::default(),
139 }
140 }
141
142 pub fn set_sign(&mut self, yes: bool) -> Result<(), git2::Error> {
143 if yes {
144 let config = self.repo.config()?;
145 let sign = git2_ext::ops::UserSign::from_config(&self.repo, &config)?;
146 self.sign = Some(sign);
147 } else {
148 self.sign = None;
149 }
150 Ok(())
151 }
152
153 pub fn set_push_remote(&mut self, remote: &str) {
154 self.push_remote = Some(remote.to_owned());
155 }
156
157 pub fn set_pull_remote(&mut self, remote: &str) {
158 self.pull_remote = Some(remote.to_owned());
159 }
160
161 pub fn push_remote(&self) -> &str {
162 self.push_remote.as_deref().unwrap_or("origin")
163 }
164
165 pub fn pull_remote(&self) -> &str {
166 self.pull_remote.as_deref().unwrap_or("origin")
167 }
168
169 pub fn raw(&self) -> &git2::Repository {
170 &self.repo
171 }
172
173 pub fn user(&self) -> Option<std::rc::Rc<str>> {
174 self.repo
175 .signature()
176 .ok()
177 .and_then(|s| s.name().map(|n| self.intern_string(n)))
178 }
179
180 pub fn is_dirty(&self) -> bool {
181 if self.repo.state() != git2::RepositoryState::Clean {
182 log::trace!("Repository status is unclean: {:?}", self.repo.state());
183 return true;
184 }
185
186 let status = self
187 .repo
188 .statuses(Some(git2::StatusOptions::new().include_ignored(false)))
189 .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
190 if status.is_empty() {
191 false
192 } else {
193 log::trace!(
194 "Repository is dirty: {}",
195 status
196 .iter()
197 .filter_map(|s| s.path().map(|s| s.to_owned()))
198 .join(", ")
199 );
200 true
201 }
202 }
203
204 pub fn merge_base(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid> {
205 if one == two {
206 return Some(one);
207 }
208
209 let (smaller, larger) = if one < two { (one, two) } else { (two, one) };
210 *self
211 .bases
212 .borrow_mut()
213 .entry((smaller, larger))
214 .or_insert_with(|| self.merge_base_raw(smaller, larger))
215 }
216
217 fn merge_base_raw(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid> {
218 self.repo.merge_base(one, two).ok()
219 }
220
221 pub fn find_commit(&self, id: git2::Oid) -> Option<std::rc::Rc<Commit>> {
222 let mut commits = self.commits.borrow_mut();
223 if let Some(commit) = commits.get(&id) {
224 Some(std::rc::Rc::clone(commit))
225 } else {
226 let commit = self.repo.find_commit(id).ok()?;
227 let summary: bstr::BString = commit.summary_bytes().unwrap().into();
228 let time = std::time::SystemTime::UNIX_EPOCH
229 + std::time::Duration::from_secs(commit.time().seconds().max(0) as u64);
230
231 let author = commit.author().name().map(|n| self.intern_string(n));
232 let committer = commit.author().name().map(|n| self.intern_string(n));
233 let commit = std::rc::Rc::new(Commit {
234 id: commit.id(),
235 tree_id: commit.tree_id(),
236 summary,
237 time,
238 author,
239 committer,
240 });
241 commits.insert(id, std::rc::Rc::clone(&commit));
242 Some(commit)
243 }
244 }
245
246 pub fn head_commit(&self) -> std::rc::Rc<Commit> {
247 let head_id = self
248 .repo
249 .head()
250 .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"))
251 .resolve()
252 .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"))
253 .target()
254 .unwrap();
255 self.find_commit(head_id).unwrap()
256 }
257
258 pub fn head_branch(&self) -> Option<Branch> {
259 let resolved = self
260 .repo
261 .head()
262 .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"))
263 .resolve()
264 .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
265 let name = resolved.shorthand()?;
266 let id = resolved.target()?;
267
268 let push_id = self
269 .repo
270 .find_branch(
271 &format!("{}/{}", self.push_remote(), name),
272 git2::BranchType::Remote,
273 )
274 .ok()
275 .and_then(|b| b.get().target());
276 let pull_id = self
277 .repo
278 .find_branch(
279 &format!("{}/{}", self.pull_remote(), name),
280 git2::BranchType::Remote,
281 )
282 .ok()
283 .and_then(|b| b.get().target());
284
285 Some(Branch {
286 remote: None,
287 name: name.to_owned(),
288 id,
289 push_id,
290 pull_id,
291 })
292 }
293
294 pub fn resolve(&self, revspec: &str) -> Option<std::rc::Rc<Commit>> {
295 let id = self.repo.revparse_single(revspec).ok()?.id();
296 self.find_commit(id)
297 }
298
299 pub fn parent_ids(&self, head_id: git2::Oid) -> Result<Vec<git2::Oid>, git2::Error> {
300 let commit = self.repo.find_commit(head_id)?;
301 Ok(commit.parent_ids().collect())
302 }
303
304 pub fn commit_count(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize> {
305 if base_id == head_id {
306 return Some(0);
307 }
308
309 *self
310 .counts
311 .borrow_mut()
312 .entry((base_id, head_id))
313 .or_insert_with(|| self.commit_count_raw(base_id, head_id))
314 }
315
316 fn commit_count_raw(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize> {
317 let merge_base_id = self.merge_base(base_id, head_id)?;
318 if merge_base_id != base_id {
319 return None;
320 }
321 let mut revwalk = self
322 .repo
323 .revwalk()
324 .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
325 revwalk
326 .push(head_id)
327 .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
328 revwalk
329 .hide(base_id)
330 .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
331 Some(revwalk.count())
332 }
333
334 pub fn commit_range(
335 &self,
336 base_bound: std::ops::Bound<&git2::Oid>,
337 head_bound: std::ops::Bound<&git2::Oid>,
338 ) -> Result<Vec<git2::Oid>, git2::Error> {
339 let head_id = match head_bound {
340 std::ops::Bound::Included(head_id) | std::ops::Bound::Excluded(head_id) => *head_id,
341 std::ops::Bound::Unbounded => panic!("commit_range's HEAD cannot be unbounded"),
342 };
343 let skip = if matches!(head_bound, std::ops::Bound::Included(_)) {
344 0
345 } else {
346 1
347 };
348
349 let base_id = match base_bound {
350 std::ops::Bound::Included(base_id) | std::ops::Bound::Excluded(base_id) => {
351 debug_assert_eq!(self.merge_base(*base_id, head_id), Some(*base_id));
352 Some(*base_id)
353 }
354 std::ops::Bound::Unbounded => None,
355 };
356
357 let mut revwalk = self.repo.revwalk()?;
358 revwalk.push(head_id)?;
359 if let Some(base_id) = base_id {
360 revwalk.hide(base_id)?;
361 }
362 revwalk.set_sorting(git2::Sort::TOPOLOGICAL)?;
363 let mut result = revwalk
364 .filter_map(Result::ok)
365 .skip(skip)
366 .take_while(|id| Some(*id) != base_id)
367 .collect::<Vec<_>>();
368 if let std::ops::Bound::Included(base_id) = base_bound {
369 result.push(*base_id);
370 }
371 Ok(result)
372 }
373
374 pub fn contains_commit(
375 &self,
376 haystack_id: git2::Oid,
377 needle_id: git2::Oid,
378 ) -> Result<bool, git2::Error> {
379 let needle_commit = self.repo.find_commit(needle_id)?;
380 let needle_ann_commit = self.repo.find_annotated_commit(needle_id)?;
381 let haystack_ann_commit = self.repo.find_annotated_commit(haystack_id)?;
382
383 let parent_ann_commit = if 0 < needle_commit.parent_count() {
384 let parent_commit = needle_commit.parent(0)?;
385 Some(self.repo.find_annotated_commit(parent_commit.id())?)
386 } else {
387 None
388 };
389
390 let mut rebase = self.repo.rebase(
391 Some(&needle_ann_commit),
392 parent_ann_commit.as_ref(),
393 Some(&haystack_ann_commit),
394 Some(git2::RebaseOptions::new().inmemory(true)),
395 )?;
396
397 if let Some(op) = rebase.next() {
398 op.map_err(|e| {
399 let _ = rebase.abort();
400 e
401 })?;
402 let inmemory_index = rebase
403 .inmemory_index()
404 .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
405 if inmemory_index.has_conflicts() {
406 return Ok(false);
407 }
408
409 let sig = self
410 .repo
411 .signature()
412 .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
413 let result = rebase.commit(None, &sig, None).map_err(|e| {
414 let _ = rebase.abort();
415 e
416 });
417 match result {
418 Ok(_) => Ok(false),
420 Err(err) => {
421 if err.class() == git2::ErrorClass::Rebase
422 && err.code() == git2::ErrorCode::Applied
423 {
424 return Ok(true);
425 }
426 Err(err)
427 }
428 }
429 } else {
430 rebase.finish(None)?;
432 Ok(true)
433 }
434 }
435
436 fn cherry_pick(
437 &mut self,
438 head_id: git2::Oid,
439 cherry_id: git2::Oid,
440 ) -> Result<git2::Oid, git2::Error> {
441 git2_ext::ops::cherry_pick(
442 &self.repo,
443 head_id,
444 cherry_id,
445 self.sign.as_ref().map(|s| s as &dyn git2_ext::ops::Sign),
446 )
447 }
448
449 pub fn squash(
450 &mut self,
451 head_id: git2::Oid,
452 into_id: git2::Oid,
453 ) -> Result<git2::Oid, git2::Error> {
454 git2_ext::ops::squash(
455 &self.repo,
456 head_id,
457 into_id,
458 self.sign.as_ref().map(|s| s as &dyn git2_ext::ops::Sign),
459 )
460 }
461
462 pub fn stash_push(&mut self, message: Option<&str>) -> Result<git2::Oid, git2::Error> {
463 let signature = self.repo.signature()?;
464 self.repo.stash_save2(&signature, message, None)
465 }
466
467 pub fn stash_pop(&mut self, stash_id: git2::Oid) -> Result<(), git2::Error> {
468 let mut index = None;
469 self.repo.stash_foreach(|i, _, id| {
470 if *id == stash_id {
471 index = Some(i);
472 false
473 } else {
474 true
475 }
476 })?;
477 let index = index.ok_or_else(|| {
478 git2::Error::new(
479 git2::ErrorCode::NotFound,
480 git2::ErrorClass::Reference,
481 "stash ID not found",
482 )
483 })?;
484 self.repo.stash_pop(index, None)
485 }
486
487 pub fn branch(&mut self, name: &str, id: git2::Oid) -> Result<(), git2::Error> {
488 let commit = self.repo.find_commit(id)?;
489 self.repo.branch(name, &commit, true)?;
490 Ok(())
491 }
492
493 pub fn delete_branch(&mut self, name: &str) -> Result<(), git2::Error> {
494 let mut branch = self.repo.find_branch(name, git2::BranchType::Local)?;
496 branch.delete()
497 }
498
499 pub fn find_local_branch(&self, name: &str) -> Option<Branch> {
500 let branch = self.repo.find_branch(name, git2::BranchType::Local).ok()?;
501 self.load_local_branch(&branch, name).ok()
502 }
503
504 pub fn find_remote_branch(&self, remote: &str, name: &str) -> Option<Branch> {
505 let qualified = format!("{remote}/{name}");
506 let branch = self
507 .repo
508 .find_branch(&qualified, git2::BranchType::Remote)
509 .ok()?;
510 self.load_remote_branch(&branch, remote, name).ok()
511 }
512
513 pub fn local_branches(&self) -> impl Iterator<Item = Branch> + '_ {
514 log::trace!("Loading local branches");
515 self.repo
516 .branches(Some(git2::BranchType::Local))
517 .into_iter()
518 .flatten()
519 .filter_map(move |branch| {
520 let (branch, _) = branch.ok()?;
521 let name = if let Some(name) = branch.name().ok().flatten() {
522 name
523 } else {
524 log::debug!(
525 "Ignoring non-UTF8 branch {:?}",
526 branch.name_bytes().unwrap().as_bstr()
527 );
528 return None;
529 };
530 self.load_local_branch(&branch, name).ok()
531 })
532 }
533
534 pub fn remote_branches(&self) -> impl Iterator<Item = Branch> + '_ {
535 log::trace!("Loading remote branches");
536 self.repo
537 .branches(Some(git2::BranchType::Remote))
538 .into_iter()
539 .flatten()
540 .filter_map(move |branch| {
541 let (branch, _) = branch.ok()?;
542 let name = if let Some(name) = branch.name().ok().flatten() {
543 name
544 } else {
545 log::debug!(
546 "Ignoring non-UTF8 branch {:?}",
547 branch.name_bytes().unwrap().as_bstr()
548 );
549 return None;
550 };
551 let (remote, name) = name.split_once('/').unwrap();
552 self.load_remote_branch(&branch, remote, name).ok()
553 })
554 }
555
556 fn load_local_branch(
557 &self,
558 branch: &git2::Branch<'_>,
559 name: &str,
560 ) -> Result<Branch, git2::Error> {
561 let id = branch.get().target().unwrap();
562
563 let push_id = self
564 .repo
565 .find_branch(
566 &format!("{}/{}", self.push_remote(), name),
567 git2::BranchType::Remote,
568 )
569 .ok()
570 .and_then(|b| b.get().target());
571 let pull_id = self
572 .repo
573 .find_branch(
574 &format!("{}/{}", self.pull_remote(), name),
575 git2::BranchType::Remote,
576 )
577 .ok()
578 .and_then(|b| b.get().target());
579
580 Ok(Branch {
581 remote: None,
582 name: name.to_owned(),
583 id,
584 push_id,
585 pull_id,
586 })
587 }
588
589 fn load_remote_branch(
590 &self,
591 branch: &git2::Branch<'_>,
592 remote: &str,
593 name: &str,
594 ) -> Result<Branch, git2::Error> {
595 let id = branch.get().target().unwrap();
596
597 let push_id = (remote == self.push_remote()).then_some(id);
598 let pull_id = (remote == self.pull_remote()).then_some(id);
599
600 Ok(Branch {
601 remote: Some(remote.to_owned()),
602 name: name.to_owned(),
603 id,
604 push_id,
605 pull_id,
606 })
607 }
608
609 pub fn detach(&mut self) -> Result<(), git2::Error> {
610 let head_id = self
611 .repo
612 .head()
613 .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"))
614 .resolve()
615 .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"))
616 .target()
617 .unwrap();
618 self.repo.set_head_detached(head_id)?;
619 Ok(())
620 }
621
622 pub fn switch(&mut self, name: &str) -> Result<(), git2::Error> {
623 let branch = self.repo.find_branch(name, git2::BranchType::Local)?;
625 self.repo.set_head(branch.get().name().unwrap())?;
626 let mut builder = git2::build::CheckoutBuilder::new();
627 builder.force();
628 self.repo.checkout_head(Some(&mut builder))?;
629 Ok(())
630 }
631
632 fn intern_string(&self, data: &str) -> std::rc::Rc<str> {
633 let mut interned_strings = self.interned_strings.borrow_mut();
634 if let Some(interned) = interned_strings.get(data) {
635 std::rc::Rc::clone(interned)
636 } else {
637 let interned = std::rc::Rc::from(data);
638 interned_strings.insert(std::rc::Rc::clone(&interned));
639 interned
640 }
641 }
642}
643
644impl std::fmt::Debug for GitRepo {
645 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
646 f.debug_struct("GitRepo")
647 .field("repo", &self.repo.workdir())
648 .field("push_remote", &self.push_remote.as_deref())
649 .field("pull_remote", &self.pull_remote.as_deref())
650 .finish()
651 }
652}
653
654impl Repo for GitRepo {
655 fn path(&self) -> Option<&std::path::Path> {
656 Some(self.repo.path())
657 }
658 fn user(&self) -> Option<std::rc::Rc<str>> {
659 self.user()
660 }
661
662 fn is_dirty(&self) -> bool {
663 self.is_dirty()
664 }
665
666 fn merge_base(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid> {
667 self.merge_base(one, two)
668 }
669
670 fn find_commit(&self, id: git2::Oid) -> Option<std::rc::Rc<Commit>> {
671 self.find_commit(id)
672 }
673
674 fn head_commit(&self) -> std::rc::Rc<Commit> {
675 self.head_commit()
676 }
677
678 fn head_branch(&self) -> Option<Branch> {
679 self.head_branch()
680 }
681
682 fn resolve(&self, revspec: &str) -> Option<std::rc::Rc<Commit>> {
683 self.resolve(revspec)
684 }
685
686 fn parent_ids(&self, head_id: git2::Oid) -> Result<Vec<git2::Oid>, git2::Error> {
687 self.parent_ids(head_id)
688 }
689
690 fn commit_count(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize> {
691 self.commit_count(base_id, head_id)
692 }
693
694 fn commit_range(
695 &self,
696 base_bound: std::ops::Bound<&git2::Oid>,
697 head_bound: std::ops::Bound<&git2::Oid>,
698 ) -> Result<Vec<git2::Oid>, git2::Error> {
699 self.commit_range(base_bound, head_bound)
700 }
701
702 fn contains_commit(
703 &self,
704 haystack_id: git2::Oid,
705 needle_id: git2::Oid,
706 ) -> Result<bool, git2::Error> {
707 self.contains_commit(haystack_id, needle_id)
708 }
709
710 fn cherry_pick(
711 &mut self,
712 head_id: git2::Oid,
713 cherry_id: git2::Oid,
714 ) -> Result<git2::Oid, git2::Error> {
715 self.cherry_pick(head_id, cherry_id)
716 }
717
718 fn squash(&mut self, head_id: git2::Oid, into_id: git2::Oid) -> Result<git2::Oid, git2::Error> {
719 self.squash(head_id, into_id)
720 }
721
722 fn stash_push(&mut self, message: Option<&str>) -> Result<git2::Oid, git2::Error> {
723 self.stash_push(message)
724 }
725
726 fn stash_pop(&mut self, stash_id: git2::Oid) -> Result<(), git2::Error> {
727 self.stash_pop(stash_id)
728 }
729
730 fn branch(&mut self, name: &str, id: git2::Oid) -> Result<(), git2::Error> {
731 self.branch(name, id)
732 }
733
734 fn delete_branch(&mut self, name: &str) -> Result<(), git2::Error> {
735 self.delete_branch(name)
736 }
737
738 fn find_local_branch(&self, name: &str) -> Option<Branch> {
739 self.find_local_branch(name)
740 }
741
742 fn find_remote_branch(&self, remote: &str, name: &str) -> Option<Branch> {
743 self.find_remote_branch(remote, name)
744 }
745
746 fn local_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_> {
747 Box::new(self.local_branches())
748 }
749
750 fn remote_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_> {
751 Box::new(self.remote_branches())
752 }
753
754 fn detach(&mut self) -> Result<(), git2::Error> {
755 self.detach()
756 }
757
758 fn switch(&mut self, name: &str) -> Result<(), git2::Error> {
759 self.switch(name)
760 }
761}
762
763#[derive(Debug)]
764pub struct InMemoryRepo {
765 commits: std::collections::HashMap<git2::Oid, (Option<git2::Oid>, std::rc::Rc<Commit>)>,
766 branches: std::collections::HashMap<String, Branch>,
767 head_id: Option<git2::Oid>,
768
769 last_id: std::sync::atomic::AtomicUsize,
770}
771
772impl InMemoryRepo {
773 pub fn new() -> Self {
774 Self {
775 commits: Default::default(),
776 branches: Default::default(),
777 head_id: Default::default(),
778 last_id: std::sync::atomic::AtomicUsize::new(1),
779 }
780 }
781
782 pub fn clear(&mut self) {
783 *self = InMemoryRepo::new();
784 }
785
786 pub fn gen_id(&mut self) -> git2::Oid {
787 let last_id = self
788 .last_id
789 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
790 let sha = format!("{last_id:040x}");
791 git2::Oid::from_str(&sha).unwrap()
792 }
793
794 pub fn push_commit(&mut self, parent_id: Option<git2::Oid>, commit: Commit) {
795 if let Some(parent_id) = parent_id {
796 assert!(self.commits.contains_key(&parent_id));
797 }
798 self.head_id = Some(commit.id);
799 self.commits
800 .insert(commit.id, (parent_id, std::rc::Rc::new(commit)));
801 }
802
803 pub fn head_id(&mut self) -> Option<git2::Oid> {
804 self.head_id
805 }
806
807 pub fn set_head(&mut self, head_id: git2::Oid) {
808 assert!(self.commits.contains_key(&head_id));
809 self.head_id = Some(head_id);
810 }
811
812 pub fn mark_branch(&mut self, branch: Branch) {
813 assert!(self.commits.contains_key(&branch.id));
814 self.branches.insert(branch.name.clone(), branch);
815 }
816
817 fn user(&self) -> Option<std::rc::Rc<str>> {
818 None
819 }
820
821 pub fn is_dirty(&self) -> bool {
822 false
823 }
824
825 pub fn merge_base(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid> {
826 let one_ancestors: Vec<_> = self.commits_from(one).collect();
827 self.commits_from(two)
828 .filter(|two_ancestor| one_ancestors.contains(two_ancestor))
829 .map(|c| c.id)
830 .next()
831 }
832
833 pub fn find_commit(&self, id: git2::Oid) -> Option<std::rc::Rc<Commit>> {
834 self.commits.get(&id).map(|c| c.1.clone())
835 }
836
837 pub fn head_commit(&self) -> std::rc::Rc<Commit> {
838 self.commits.get(&self.head_id.unwrap()).cloned().unwrap().1
839 }
840
841 pub fn head_branch(&self) -> Option<Branch> {
842 self.branches
843 .values()
844 .find(|b| b.id == self.head_id.unwrap())
845 .cloned()
846 }
847
848 pub fn resolve(&self, revspec: &str) -> Option<std::rc::Rc<Commit>> {
849 let branch = self.branches.get(revspec)?;
850 self.find_commit(branch.id)
851 }
852
853 pub fn parent_ids(&self, head_id: git2::Oid) -> Result<Vec<git2::Oid>, git2::Error> {
854 let next = self
855 .commits
856 .get(&head_id)
857 .and_then(|(parent, _commit)| *parent);
858 Ok(next.into_iter().collect())
859 }
860
861 fn commits_from(&self, head_id: git2::Oid) -> impl Iterator<Item = std::rc::Rc<Commit>> + '_ {
862 let next = self.commits.get(&head_id).cloned();
863 CommitsFrom {
864 commits: &self.commits,
865 next,
866 }
867 }
868
869 pub fn commit_count(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize> {
870 let merge_base_id = self.merge_base(base_id, head_id)?;
871 let count = self
872 .commits_from(head_id)
873 .take_while(move |cur_id| cur_id.id != merge_base_id)
874 .count();
875 Some(count)
876 }
877
878 pub fn commit_range(
879 &self,
880 base_bound: std::ops::Bound<&git2::Oid>,
881 head_bound: std::ops::Bound<&git2::Oid>,
882 ) -> Result<Vec<git2::Oid>, git2::Error> {
883 let head_id = match head_bound {
884 std::ops::Bound::Included(head_id) | std::ops::Bound::Excluded(head_id) => *head_id,
885 std::ops::Bound::Unbounded => panic!("commit_range's HEAD cannot be unbounded"),
886 };
887 let skip = if matches!(head_bound, std::ops::Bound::Included(_)) {
888 0
889 } else {
890 1
891 };
892
893 let base_id = match base_bound {
894 std::ops::Bound::Included(base_id) | std::ops::Bound::Excluded(base_id) => {
895 debug_assert_eq!(self.merge_base(*base_id, head_id), Some(*base_id));
896 Some(*base_id)
897 }
898 std::ops::Bound::Unbounded => None,
899 };
900
901 let mut result = self
902 .commits_from(head_id)
903 .skip(skip)
904 .map(|commit| commit.id)
905 .take_while(|id| Some(*id) != base_id)
906 .collect::<Vec<_>>();
907 if let std::ops::Bound::Included(base_id) = base_bound {
908 result.push(*base_id);
909 }
910 Ok(result)
911 }
912
913 pub fn contains_commit(
914 &self,
915 haystack_id: git2::Oid,
916 needle_id: git2::Oid,
917 ) -> Result<bool, git2::Error> {
918 let mut next = Some(haystack_id);
920 while let Some(current) = next {
921 if current == needle_id {
922 return Ok(true);
923 }
924 next = self.commits.get(¤t).and_then(|c| c.0);
925 }
926 Ok(false)
927 }
928
929 pub fn cherry_pick(
930 &mut self,
931 head_id: git2::Oid,
932 cherry_id: git2::Oid,
933 ) -> Result<git2::Oid, git2::Error> {
934 let cherry_commit = self.find_commit(cherry_id).ok_or_else(|| {
935 git2::Error::new(
936 git2::ErrorCode::NotFound,
937 git2::ErrorClass::Reference,
938 format!("could not find commit {cherry_id:?}"),
939 )
940 })?;
941 let mut cherry_commit = Commit::clone(&cherry_commit);
942 let new_id = self.gen_id();
943 cherry_commit.id = new_id;
944 self.commits
945 .insert(new_id, (Some(head_id), std::rc::Rc::new(cherry_commit)));
946 Ok(new_id)
947 }
948
949 pub fn squash(
950 &mut self,
951 head_id: git2::Oid,
952 into_id: git2::Oid,
953 ) -> Result<git2::Oid, git2::Error> {
954 self.commits.get(&head_id).cloned().ok_or_else(|| {
955 git2::Error::new(
956 git2::ErrorCode::NotFound,
957 git2::ErrorClass::Reference,
958 format!("could not find commit {head_id:?}"),
959 )
960 })?;
961 let (intos_parent, into_commit) = self.commits.get(&into_id).cloned().ok_or_else(|| {
962 git2::Error::new(
963 git2::ErrorCode::NotFound,
964 git2::ErrorClass::Reference,
965 format!("could not find commit {into_id:?}"),
966 )
967 })?;
968 let intos_parent = intos_parent.unwrap();
969
970 let mut squashed_commit = Commit::clone(&into_commit);
971 let new_id = self.gen_id();
972 squashed_commit.id = new_id;
973 self.commits.insert(
974 new_id,
975 (Some(intos_parent), std::rc::Rc::new(squashed_commit)),
976 );
977 Ok(new_id)
978 }
979
980 pub fn stash_push(&mut self, _message: Option<&str>) -> Result<git2::Oid, git2::Error> {
981 Err(git2::Error::new(
982 git2::ErrorCode::NotFound,
983 git2::ErrorClass::Reference,
984 "stash is unsupported",
985 ))
986 }
987
988 pub fn stash_pop(&mut self, _stash_id: git2::Oid) -> Result<(), git2::Error> {
989 Err(git2::Error::new(
990 git2::ErrorCode::NotFound,
991 git2::ErrorClass::Reference,
992 "stash is unsupported",
993 ))
994 }
995
996 pub fn branch(&mut self, name: &str, id: git2::Oid) -> Result<(), git2::Error> {
997 self.branches.insert(
998 name.to_owned(),
999 Branch {
1000 remote: None,
1001 name: name.to_owned(),
1002 id,
1003 push_id: None,
1004 pull_id: None,
1005 },
1006 );
1007 Ok(())
1008 }
1009
1010 pub fn delete_branch(&mut self, name: &str) -> Result<(), git2::Error> {
1011 self.branches.remove(name).map(|_| ()).ok_or_else(|| {
1012 git2::Error::new(
1013 git2::ErrorCode::NotFound,
1014 git2::ErrorClass::Reference,
1015 format!("could not remove branch {name:?}"),
1016 )
1017 })
1018 }
1019
1020 pub fn find_local_branch(&self, name: &str) -> Option<Branch> {
1021 self.branches.get(name).cloned()
1022 }
1023
1024 pub fn find_remote_branch(&self, _remote: &str, _name: &str) -> Option<Branch> {
1025 None
1026 }
1027
1028 pub fn local_branches(&self) -> impl Iterator<Item = Branch> + '_ {
1029 self.branches.values().cloned()
1030 }
1031
1032 pub fn remote_branches(&self) -> impl Iterator<Item = Branch> + '_ {
1033 None.into_iter()
1034 }
1035
1036 pub fn detach(&mut self) -> Result<(), git2::Error> {
1037 Ok(())
1038 }
1039
1040 pub fn switch(&mut self, name: &str) -> Result<(), git2::Error> {
1041 let branch = self.find_local_branch(name).ok_or_else(|| {
1042 git2::Error::new(
1043 git2::ErrorCode::NotFound,
1044 git2::ErrorClass::Reference,
1045 format!("could not find branch {name:?}"),
1046 )
1047 })?;
1048 self.head_id = Some(branch.id);
1049 Ok(())
1050 }
1051}
1052
1053impl Default for InMemoryRepo {
1054 fn default() -> Self {
1055 Self::new()
1056 }
1057}
1058
1059struct CommitsFrom<'c> {
1060 commits: &'c std::collections::HashMap<git2::Oid, (Option<git2::Oid>, std::rc::Rc<Commit>)>,
1061 next: Option<(Option<git2::Oid>, std::rc::Rc<Commit>)>,
1062}
1063
1064impl Iterator for CommitsFrom<'_> {
1065 type Item = std::rc::Rc<Commit>;
1066
1067 fn next(&mut self) -> Option<Self::Item> {
1068 let mut current = None;
1069 std::mem::swap(&mut current, &mut self.next);
1070 let current = current?;
1071 if let Some(parent_id) = current.0 {
1072 self.next = self.commits.get(&parent_id).cloned();
1073 }
1074 Some(current.1)
1075 }
1076}
1077
1078impl Repo for InMemoryRepo {
1079 fn path(&self) -> Option<&std::path::Path> {
1080 None
1081 }
1082 fn user(&self) -> Option<std::rc::Rc<str>> {
1083 self.user()
1084 }
1085
1086 fn is_dirty(&self) -> bool {
1087 self.is_dirty()
1088 }
1089
1090 fn merge_base(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid> {
1091 self.merge_base(one, two)
1092 }
1093
1094 fn find_commit(&self, id: git2::Oid) -> Option<std::rc::Rc<Commit>> {
1095 self.find_commit(id)
1096 }
1097
1098 fn head_commit(&self) -> std::rc::Rc<Commit> {
1099 self.head_commit()
1100 }
1101
1102 fn resolve(&self, revspec: &str) -> Option<std::rc::Rc<Commit>> {
1103 self.resolve(revspec)
1104 }
1105
1106 fn parent_ids(&self, head_id: git2::Oid) -> Result<Vec<git2::Oid>, git2::Error> {
1107 self.parent_ids(head_id)
1108 }
1109
1110 fn commit_count(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize> {
1111 self.commit_count(base_id, head_id)
1112 }
1113
1114 fn commit_range(
1115 &self,
1116 base_bound: std::ops::Bound<&git2::Oid>,
1117 head_bound: std::ops::Bound<&git2::Oid>,
1118 ) -> Result<Vec<git2::Oid>, git2::Error> {
1119 self.commit_range(base_bound, head_bound)
1120 }
1121
1122 fn contains_commit(
1123 &self,
1124 haystack_id: git2::Oid,
1125 needle_id: git2::Oid,
1126 ) -> Result<bool, git2::Error> {
1127 self.contains_commit(haystack_id, needle_id)
1128 }
1129
1130 fn cherry_pick(
1131 &mut self,
1132 head_id: git2::Oid,
1133 cherry_id: git2::Oid,
1134 ) -> Result<git2::Oid, git2::Error> {
1135 self.cherry_pick(head_id, cherry_id)
1136 }
1137
1138 fn squash(&mut self, head_id: git2::Oid, into_id: git2::Oid) -> Result<git2::Oid, git2::Error> {
1139 self.squash(head_id, into_id)
1140 }
1141
1142 fn head_branch(&self) -> Option<Branch> {
1143 self.head_branch()
1144 }
1145
1146 fn stash_push(&mut self, message: Option<&str>) -> Result<git2::Oid, git2::Error> {
1147 self.stash_push(message)
1148 }
1149
1150 fn stash_pop(&mut self, stash_id: git2::Oid) -> Result<(), git2::Error> {
1151 self.stash_pop(stash_id)
1152 }
1153
1154 fn branch(&mut self, name: &str, id: git2::Oid) -> Result<(), git2::Error> {
1155 self.branch(name, id)
1156 }
1157
1158 fn delete_branch(&mut self, name: &str) -> Result<(), git2::Error> {
1159 self.delete_branch(name)
1160 }
1161
1162 fn find_local_branch(&self, name: &str) -> Option<Branch> {
1163 self.find_local_branch(name)
1164 }
1165
1166 fn find_remote_branch(&self, remote: &str, name: &str) -> Option<Branch> {
1167 self.find_remote_branch(remote, name)
1168 }
1169
1170 fn local_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_> {
1171 Box::new(self.local_branches())
1172 }
1173
1174 fn remote_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_> {
1175 Box::new(self.remote_branches())
1176 }
1177
1178 fn detach(&mut self) -> Result<(), git2::Error> {
1179 self.detach()
1180 }
1181
1182 fn switch(&mut self, name: &str) -> Result<(), git2::Error> {
1183 self.switch(name)
1184 }
1185}
1186
1187pub fn stash_push(repo: &mut dyn Repo, context: &str) -> Option<git2::Oid> {
1188 let branch = repo.head_branch();
1189 let stash_msg = format!(
1190 "WIP on {} ({})",
1191 branch.as_ref().map(|b| b.name.as_str()).unwrap_or("HEAD"),
1192 context
1193 );
1194 match repo.stash_push(Some(&stash_msg)) {
1195 Ok(stash_id) => {
1196 log::info!(
1197 "Saved working directory and index state {}: {}",
1198 stash_msg,
1199 stash_id
1200 );
1201 Some(stash_id)
1202 }
1203 Err(err) => {
1204 log::debug!("Failed to stash: {}", err);
1205 None
1206 }
1207 }
1208}
1209
1210pub fn stash_pop(repo: &mut dyn Repo, stash_id: Option<git2::Oid>) {
1211 if let Some(stash_id) = stash_id {
1212 match repo.stash_pop(stash_id) {
1213 Ok(()) => {
1214 log::info!("Dropped refs/stash {}", stash_id);
1215 }
1216 Err(err) => {
1217 log::error!("Failed to pop {} from stash: {}", stash_id, err);
1218 }
1219 }
1220 }
1221}
1222
1223pub fn commit_range(
1224 repo: &dyn Repo,
1225 head_to_base: impl std::ops::RangeBounds<git2::Oid>,
1226) -> Result<Vec<git2::Oid>, git2::Error> {
1227 let head_bound = head_to_base.start_bound();
1228 let base_bound = head_to_base.end_bound();
1229 repo.commit_range(base_bound, head_bound)
1230}