1use bstr::ByteSlice;
2use itertools::Itertools;
3
/// Git operations needed by this crate, implemented by the libgit2-backed
/// [`GitRepo`] and by the in-memory [`InMemoryRepo`].
pub trait Repo {
    /// Filesystem path of the repository, or `None` when it has no on-disk location.
    fn path(&self) -> Option<&std::path::Path>;
    /// Name of the configured user, if any.
    fn user(&self) -> Option<std::rc::Rc<str>>;

    /// Whether the repository has uncommitted changes or an in-progress operation.
    fn is_dirty(&self) -> bool;
    /// Common ancestor of `one` and `two`, if they share history.
    fn merge_base(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid>;

    /// Looks up commit metadata by id.
    fn find_commit(&self, id: git2::Oid) -> Option<std::rc::Rc<Commit>>;
    /// The commit `HEAD` currently points at.
    fn head_commit(&self) -> std::rc::Rc<Commit>;
    /// The branch `HEAD` is on, when one can be determined.
    fn head_branch(&self) -> Option<Branch>;
    /// Resolves a revision specification to a commit.
    fn resolve(&self, revspec: &str) -> Option<std::rc::Rc<Commit>>;
    /// Ids of the parents of `head_id`.
    fn parent_ids(&self, head_id: git2::Oid) -> Result<Vec<git2::Oid>, git2::Error>;
    /// Number of commits between `base_id` and `head_id`; `None` when it cannot
    /// be determined (implementations differ on the exact conditions).
    fn commit_count(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize>;
    /// Commit ids walked from `head_bound` back towards `base_bound`, head first.
    ///
    /// `Included`/`Excluded` control whether each endpoint appears in the
    /// result; `head_bound` must not be `Unbounded` (both implementations panic).
    fn commit_range(
        &self,
        base_bound: std::ops::Bound<&git2::Oid>,
        head_bound: std::ops::Bound<&git2::Oid>,
    ) -> Result<Vec<git2::Oid>, git2::Error>;
    /// Whether the change introduced by `needle_id` is already present in
    /// `haystack_id` (GitRepo replays the commit; InMemoryRepo checks ancestry).
    fn contains_commit(
        &self,
        haystack_id: git2::Oid,
        needle_id: git2::Oid,
    ) -> Result<bool, git2::Error>;
    /// Applies `cherry_id` on top of `head_id`, returning the new commit id.
    fn cherry_pick(
        &mut self,
        head_id: git2::Oid,
        cherry_id: git2::Oid,
    ) -> Result<git2::Oid, git2::Error>;
    /// Squashes `head_id` into `into_id`, returning the new commit id.
    fn squash(&mut self, head_id: git2::Oid, into_id: git2::Oid) -> Result<git2::Oid, git2::Error>;

    /// Stashes the working tree, returning the stash commit id.
    fn stash_push(&mut self, message: Option<&str>) -> Result<git2::Oid, git2::Error>;
    /// Applies and drops the stash entry identified by `stash_id`.
    fn stash_pop(&mut self, stash_id: git2::Oid) -> Result<(), git2::Error>;

    /// Creates (or overwrites) local branch `name` at `id`.
    fn branch(&mut self, name: &str, id: git2::Oid) -> Result<(), git2::Error>;
    /// Deletes local branch `name`.
    fn delete_branch(&mut self, name: &str) -> Result<(), git2::Error>;
    /// Looks up local branch `name`.
    fn find_local_branch(&self, name: &str) -> Option<Branch>;
    /// Looks up remote-tracking branch `remote/name`.
    fn find_remote_branch(&self, remote: &str, name: &str) -> Option<Branch>;
    /// All loadable local branches.
    fn local_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_>;
    /// All loadable remote-tracking branches.
    fn remote_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_>;
    /// Detaches `HEAD` from its branch.
    fn detach(&mut self) -> Result<(), git2::Error>;
    /// Switches `HEAD` to local branch `name`.
    fn switch(&mut self, name: &str) -> Result<(), git2::Error>;
}
46
/// A local or remote-tracking branch plus the commits its remote counterparts point at.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Branch {
    /// Remote name for remote-tracking branches; `None` for local branches.
    pub remote: Option<String>,
    /// Branch name without any remote prefix.
    pub name: String,
    /// Commit the branch points at.
    pub id: git2::Oid,
    /// Commit of the corresponding branch on the push remote, when known.
    pub push_id: Option<git2::Oid>,
    /// Commit of the corresponding branch on the pull remote, when known.
    pub pull_id: Option<git2::Oid>,
}
55
56impl Branch {
57 pub fn local_name(&self) -> Option<&str> {
58 self.remote.is_none().then_some(self.name.as_str())
59 }
60}
61
62impl std::fmt::Display for Branch {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 if let Some(remote) = self.remote.as_deref() {
65 write!(f, "{}/{}", remote, self.name.as_str())
66 } else {
67 write!(f, "{}", self.name.as_str())
68 }
69 }
70}
71
/// Cached metadata about a single commit.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Commit {
    /// Commit id.
    pub id: git2::Oid,
    /// Id of the commit's tree.
    pub tree_id: git2::Oid,
    /// Summary of the commit message.
    pub summary: bstr::BString,
    /// Commit time.
    pub time: std::time::SystemTime,
    /// Author name, if recorded.
    pub author: Option<std::rc::Rc<str>>,
    /// Committer name, if recorded.
    pub committer: Option<std::rc::Rc<str>>,
}
81
82impl Commit {
83 pub fn fixup_summary(&self) -> Option<&bstr::BStr> {
84 self.summary
85 .strip_prefix(b"fixup! ")
86 .map(ByteSlice::as_bstr)
87 }
88
89 pub fn wip_summary(&self) -> Option<&bstr::BStr> {
90 static WIP_PREFIXES: &[&[u8]] = &[
92 b"WIP:", b"draft:", b"Draft:", b"wip ", b"WIP ", ];
95
96 if self.summary == b"WIP".as_bstr() || self.summary == b"wip".as_bstr() {
97 Some(b"".as_bstr())
99 } else {
100 WIP_PREFIXES.iter().find_map(|prefix| {
101 self.summary
102 .strip_prefix(*prefix)
103 .map(ByteSlice::trim)
104 .map(ByteSlice::as_bstr)
105 })
106 }
107 }
108
109 pub fn revert_summary(&self) -> Option<&bstr::BStr> {
110 self.summary
111 .strip_prefix(b"Revert ")
112 .and_then(|s| s.strip_suffix(b"\""))
113 .map(ByteSlice::as_bstr)
114 }
115}
116
/// [`Repo`] backed by a libgit2 repository, with per-instance memoization caches.
pub struct GitRepo {
    repo: git2::Repository,
    /// Signer applied to commits created by cherry-pick/squash; `None` means unsigned.
    sign: Option<git2_ext::ops::UserSign>,
    /// Remote used for push comparisons; `None` falls back to `origin`.
    push_remote: Option<String>,
    /// Remote used for pull comparisons; `None` falls back to `origin`.
    pull_remote: Option<String>,
    /// Commit metadata cache keyed by commit id.
    commits: std::cell::RefCell<std::collections::HashMap<git2::Oid, std::rc::Rc<Commit>>>,
    /// Interned user/author/committer names shared across cached commits.
    interned_strings: std::cell::RefCell<std::collections::HashSet<std::rc::Rc<str>>>,
    /// Memoized merge bases keyed by the ordered `(smaller, larger)` id pair.
    bases: std::cell::RefCell<std::collections::HashMap<(git2::Oid, git2::Oid), Option<git2::Oid>>>,
    /// Memoized commit counts keyed by `(base_id, head_id)`.
    counts: std::cell::RefCell<std::collections::HashMap<(git2::Oid, git2::Oid), Option<usize>>>,
}
127
128impl GitRepo {
129 pub fn new(repo: git2::Repository) -> Self {
130 Self {
131 repo,
132 sign: None,
133 push_remote: None,
134 pull_remote: None,
135 commits: Default::default(),
136 interned_strings: Default::default(),
137 bases: Default::default(),
138 counts: Default::default(),
139 }
140 }
141
142 pub fn set_sign(&mut self, yes: bool) -> Result<(), git2::Error> {
143 if yes {
144 let config = self.repo.config()?;
145 let sign = git2_ext::ops::UserSign::from_config(&self.repo, &config)?;
146 self.sign = Some(sign);
147 } else {
148 self.sign = None;
149 }
150 Ok(())
151 }
152
153 pub fn set_push_remote(&mut self, remote: &str) {
154 self.push_remote = Some(remote.to_owned());
155 }
156
157 pub fn set_pull_remote(&mut self, remote: &str) {
158 self.pull_remote = Some(remote.to_owned());
159 }
160
161 pub fn push_remote(&self) -> &str {
162 self.push_remote.as_deref().unwrap_or("origin")
163 }
164
165 pub fn pull_remote(&self) -> &str {
166 self.pull_remote.as_deref().unwrap_or("origin")
167 }
168
169 pub fn raw(&self) -> &git2::Repository {
170 &self.repo
171 }
172
173 pub fn user(&self) -> Option<std::rc::Rc<str>> {
174 self.repo
175 .signature()
176 .ok()
177 .and_then(|s| s.name().map(|n| self.intern_string(n)))
178 }
179
180 pub fn is_dirty(&self) -> bool {
181 if self.repo.state() != git2::RepositoryState::Clean {
182 log::trace!("Repository status is unclean: {:?}", self.repo.state());
183 return true;
184 }
185
186 let status = self
187 .repo
188 .statuses(Some(git2::StatusOptions::new().include_ignored(false)))
189 .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
190 if status.is_empty() {
191 false
192 } else {
193 log::trace!(
194 "Repository is dirty: {}",
195 status
196 .iter()
197 .filter_map(|s| s.path().map(|s| s.to_owned()))
198 .join(", ")
199 );
200 true
201 }
202 }
203
204 pub fn merge_base(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid> {
205 if one == two {
206 return Some(one);
207 }
208
209 let (smaller, larger) = if one < two { (one, two) } else { (two, one) };
210 *self
211 .bases
212 .borrow_mut()
213 .entry((smaller, larger))
214 .or_insert_with(|| self.merge_base_raw(smaller, larger))
215 }
216
217 fn merge_base_raw(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid> {
218 self.repo.merge_base(one, two).ok()
219 }
220
221 pub fn find_commit(&self, id: git2::Oid) -> Option<std::rc::Rc<Commit>> {
222 let mut commits = self.commits.borrow_mut();
223 if let Some(commit) = commits.get(&id) {
224 Some(std::rc::Rc::clone(commit))
225 } else {
226 let commit = self.repo.find_commit(id).ok()?;
227 let summary: bstr::BString = commit.summary_bytes().unwrap().into();
228 let time = std::time::SystemTime::UNIX_EPOCH
229 + std::time::Duration::from_secs(commit.time().seconds().max(0) as u64);
230
231 let author = commit.author().name().map(|n| self.intern_string(n));
232 let committer = commit.author().name().map(|n| self.intern_string(n));
233 let commit = std::rc::Rc::new(Commit {
234 id: commit.id(),
235 tree_id: commit.tree_id(),
236 summary,
237 time,
238 author,
239 committer,
240 });
241 commits.insert(id, std::rc::Rc::clone(&commit));
242 Some(commit)
243 }
244 }
245
246 pub fn head_commit(&self) -> std::rc::Rc<Commit> {
247 let head_id = self
248 .repo
249 .head()
250 .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"))
251 .resolve()
252 .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"))
253 .target()
254 .unwrap();
255 self.find_commit(head_id).unwrap()
256 }
257
258 pub fn head_branch(&self) -> Option<Branch> {
259 let resolved = self
260 .repo
261 .head()
262 .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"))
263 .resolve()
264 .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
265 let name = resolved.shorthand()?;
266 let id = resolved.target()?;
267
268 let push_id = self
269 .repo
270 .find_branch(
271 &format!("{}/{}", self.push_remote(), name),
272 git2::BranchType::Remote,
273 )
274 .ok()
275 .and_then(|b| b.get().target());
276 let pull_id = self
277 .repo
278 .find_branch(
279 &format!("{}/{}", self.pull_remote(), name),
280 git2::BranchType::Remote,
281 )
282 .ok()
283 .and_then(|b| b.get().target());
284
285 Some(Branch {
286 remote: None,
287 name: name.to_owned(),
288 id,
289 push_id,
290 pull_id,
291 })
292 }
293
294 pub fn resolve(&self, revspec: &str) -> Option<std::rc::Rc<Commit>> {
295 let id = self.repo.revparse_single(revspec).ok()?.id();
296 self.find_commit(id)
297 }
298
299 pub fn parent_ids(&self, head_id: git2::Oid) -> Result<Vec<git2::Oid>, git2::Error> {
300 let commit = self.repo.find_commit(head_id)?;
301 Ok(commit.parent_ids().collect())
302 }
303
304 pub fn commit_count(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize> {
305 if base_id == head_id {
306 return Some(0);
307 }
308
309 *self
310 .counts
311 .borrow_mut()
312 .entry((base_id, head_id))
313 .or_insert_with(|| self.commit_count_raw(base_id, head_id))
314 }
315
316 fn commit_count_raw(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize> {
317 let merge_base_id = self.merge_base(base_id, head_id)?;
318 if merge_base_id != base_id {
319 return None;
320 }
321 let mut revwalk = self
322 .repo
323 .revwalk()
324 .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
325 revwalk
326 .push(head_id)
327 .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
328 revwalk
329 .hide(base_id)
330 .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
331 Some(revwalk.count())
332 }
333
334 pub fn commit_range(
335 &self,
336 base_bound: std::ops::Bound<&git2::Oid>,
337 head_bound: std::ops::Bound<&git2::Oid>,
338 ) -> Result<Vec<git2::Oid>, git2::Error> {
339 let head_id = match head_bound {
340 std::ops::Bound::Included(head_id) | std::ops::Bound::Excluded(head_id) => *head_id,
341 std::ops::Bound::Unbounded => panic!("commit_range's HEAD cannot be unbounded"),
342 };
343 let skip = if matches!(head_bound, std::ops::Bound::Included(_)) {
344 0
345 } else {
346 1
347 };
348
349 let base_id = match base_bound {
350 std::ops::Bound::Included(base_id) | std::ops::Bound::Excluded(base_id) => {
351 debug_assert_eq!(self.merge_base(*base_id, head_id), Some(*base_id));
352 Some(*base_id)
353 }
354 std::ops::Bound::Unbounded => None,
355 };
356
357 let mut revwalk = self.repo.revwalk()?;
358 revwalk.push(head_id)?;
359 if let Some(base_id) = base_id {
360 revwalk.hide(base_id)?;
361 }
362 revwalk.set_sorting(git2::Sort::TOPOLOGICAL)?;
363 let mut result = revwalk
364 .filter_map(Result::ok)
365 .skip(skip)
366 .take_while(|id| Some(*id) != base_id)
367 .collect::<Vec<_>>();
368 if let std::ops::Bound::Included(base_id) = base_bound {
369 result.push(*base_id);
370 }
371 Ok(result)
372 }
373
374 pub fn contains_commit(
375 &self,
376 haystack_id: git2::Oid,
377 needle_id: git2::Oid,
378 ) -> Result<bool, git2::Error> {
379 let needle_commit = self.repo.find_commit(needle_id)?;
380 let needle_ann_commit = self.repo.find_annotated_commit(needle_id)?;
381 let haystack_ann_commit = self.repo.find_annotated_commit(haystack_id)?;
382
383 let parent_ann_commit = if 0 < needle_commit.parent_count() {
384 let parent_commit = needle_commit.parent(0)?;
385 Some(self.repo.find_annotated_commit(parent_commit.id())?)
386 } else {
387 None
388 };
389
390 let mut rebase = self.repo.rebase(
391 Some(&needle_ann_commit),
392 parent_ann_commit.as_ref(),
393 Some(&haystack_ann_commit),
394 Some(git2::RebaseOptions::new().inmemory(true)),
395 )?;
396
397 if let Some(op) = rebase.next() {
398 op.inspect_err(|_e| {
399 let _ = rebase.abort();
400 })?;
401 let inmemory_index = rebase
402 .inmemory_index()
403 .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
404 if inmemory_index.has_conflicts() {
405 return Ok(false);
406 }
407
408 let sig = self
409 .repo
410 .signature()
411 .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"));
412 let result = rebase.commit(None, &sig, None).inspect_err(|_e| {
413 let _ = rebase.abort();
414 });
415 match result {
416 Ok(_) => Ok(false),
418 Err(err) => {
419 if err.class() == git2::ErrorClass::Rebase
420 && err.code() == git2::ErrorCode::Applied
421 {
422 return Ok(true);
423 }
424 Err(err)
425 }
426 }
427 } else {
428 rebase.finish(None)?;
430 Ok(true)
431 }
432 }
433
434 fn cherry_pick(
435 &mut self,
436 head_id: git2::Oid,
437 cherry_id: git2::Oid,
438 ) -> Result<git2::Oid, git2::Error> {
439 git2_ext::ops::cherry_pick(
440 &self.repo,
441 head_id,
442 cherry_id,
443 self.sign.as_ref().map(|s| s as &dyn git2_ext::ops::Sign),
444 )
445 }
446
447 pub fn squash(
448 &mut self,
449 head_id: git2::Oid,
450 into_id: git2::Oid,
451 ) -> Result<git2::Oid, git2::Error> {
452 git2_ext::ops::squash(
453 &self.repo,
454 head_id,
455 into_id,
456 self.sign.as_ref().map(|s| s as &dyn git2_ext::ops::Sign),
457 )
458 }
459
460 pub fn stash_push(&mut self, message: Option<&str>) -> Result<git2::Oid, git2::Error> {
461 let signature = self.repo.signature()?;
462 self.repo.stash_save2(&signature, message, None)
463 }
464
465 pub fn stash_pop(&mut self, stash_id: git2::Oid) -> Result<(), git2::Error> {
466 let mut index = None;
467 self.repo.stash_foreach(|i, _, id| {
468 if *id == stash_id {
469 index = Some(i);
470 false
471 } else {
472 true
473 }
474 })?;
475 let index = index.ok_or_else(|| {
476 git2::Error::new(
477 git2::ErrorCode::NotFound,
478 git2::ErrorClass::Reference,
479 "stash ID not found",
480 )
481 })?;
482 self.repo.stash_pop(index, None)
483 }
484
485 pub fn branch(&mut self, name: &str, id: git2::Oid) -> Result<(), git2::Error> {
486 let commit = self.repo.find_commit(id)?;
487 self.repo.branch(name, &commit, true)?;
488 Ok(())
489 }
490
491 pub fn delete_branch(&mut self, name: &str) -> Result<(), git2::Error> {
492 let mut branch = self.repo.find_branch(name, git2::BranchType::Local)?;
494 branch.delete()
495 }
496
497 pub fn find_local_branch(&self, name: &str) -> Option<Branch> {
498 let branch = self.repo.find_branch(name, git2::BranchType::Local).ok()?;
499 self.load_local_branch(&branch, name).ok()
500 }
501
502 pub fn find_remote_branch(&self, remote: &str, name: &str) -> Option<Branch> {
503 let qualified = format!("{remote}/{name}");
504 let branch = self
505 .repo
506 .find_branch(&qualified, git2::BranchType::Remote)
507 .ok()?;
508 self.load_remote_branch(&branch, remote, name).ok()
509 }
510
511 pub fn local_branches(&self) -> impl Iterator<Item = Branch> + '_ {
512 log::trace!("Loading local branches");
513 self.repo
514 .branches(Some(git2::BranchType::Local))
515 .into_iter()
516 .flatten()
517 .filter_map(move |branch| {
518 let (branch, _) = branch.ok()?;
519 let name = if let Some(name) = branch.name().ok().flatten() {
520 name
521 } else {
522 log::debug!(
523 "Ignoring non-UTF8 branch {:?}",
524 branch.name_bytes().unwrap().as_bstr()
525 );
526 return None;
527 };
528 self.load_local_branch(&branch, name).ok()
529 })
530 }
531
532 pub fn remote_branches(&self) -> impl Iterator<Item = Branch> + '_ {
533 log::trace!("Loading remote branches");
534 self.repo
535 .branches(Some(git2::BranchType::Remote))
536 .into_iter()
537 .flatten()
538 .filter_map(move |branch| {
539 let (branch, _) = branch.ok()?;
540 let name = if let Some(name) = branch.name().ok().flatten() {
541 name
542 } else {
543 log::debug!(
544 "Ignoring non-UTF8 branch {:?}",
545 branch.name_bytes().unwrap().as_bstr()
546 );
547 return None;
548 };
549 let (remote, name) = name.split_once('/').unwrap();
550 self.load_remote_branch(&branch, remote, name).ok()
551 })
552 }
553
554 fn load_local_branch(
555 &self,
556 branch: &git2::Branch<'_>,
557 name: &str,
558 ) -> Result<Branch, git2::Error> {
559 let id = branch.get().target().unwrap();
560
561 let push_id = self
562 .repo
563 .find_branch(
564 &format!("{}/{}", self.push_remote(), name),
565 git2::BranchType::Remote,
566 )
567 .ok()
568 .and_then(|b| b.get().target());
569 let pull_id = self
570 .repo
571 .find_branch(
572 &format!("{}/{}", self.pull_remote(), name),
573 git2::BranchType::Remote,
574 )
575 .ok()
576 .and_then(|b| b.get().target());
577
578 Ok(Branch {
579 remote: None,
580 name: name.to_owned(),
581 id,
582 push_id,
583 pull_id,
584 })
585 }
586
587 fn load_remote_branch(
588 &self,
589 branch: &git2::Branch<'_>,
590 remote: &str,
591 name: &str,
592 ) -> Result<Branch, git2::Error> {
593 let id = branch.get().target().unwrap();
594
595 let push_id = (remote == self.push_remote()).then_some(id);
596 let pull_id = (remote == self.pull_remote()).then_some(id);
597
598 Ok(Branch {
599 remote: Some(remote.to_owned()),
600 name: name.to_owned(),
601 id,
602 push_id,
603 pull_id,
604 })
605 }
606
607 pub fn detach(&mut self) -> Result<(), git2::Error> {
608 let head_id = self
609 .repo
610 .head()
611 .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"))
612 .resolve()
613 .unwrap_or_else(|e| panic!("Unexpected git2 error: {e}"))
614 .target()
615 .unwrap();
616 self.repo.set_head_detached(head_id)?;
617 Ok(())
618 }
619
620 pub fn switch(&mut self, name: &str) -> Result<(), git2::Error> {
621 let branch = self.repo.find_branch(name, git2::BranchType::Local)?;
623 self.repo.set_head(branch.get().name().unwrap())?;
624 let mut builder = git2::build::CheckoutBuilder::new();
625 builder.force();
626 self.repo.checkout_head(Some(&mut builder))?;
627 Ok(())
628 }
629
630 fn intern_string(&self, data: &str) -> std::rc::Rc<str> {
631 let mut interned_strings = self.interned_strings.borrow_mut();
632 if let Some(interned) = interned_strings.get(data) {
633 std::rc::Rc::clone(interned)
634 } else {
635 let interned = std::rc::Rc::from(data);
636 interned_strings.insert(std::rc::Rc::clone(&interned));
637 interned
638 }
639 }
640}
641
642impl std::fmt::Debug for GitRepo {
643 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
644 f.debug_struct("GitRepo")
645 .field("repo", &self.repo.workdir())
646 .field("push_remote", &self.push_remote.as_deref())
647 .field("pull_remote", &self.pull_remote.as_deref())
648 .finish()
649 }
650}
651
652impl Repo for GitRepo {
653 fn path(&self) -> Option<&std::path::Path> {
654 Some(self.repo.path())
655 }
656 fn user(&self) -> Option<std::rc::Rc<str>> {
657 self.user()
658 }
659
660 fn is_dirty(&self) -> bool {
661 self.is_dirty()
662 }
663
664 fn merge_base(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid> {
665 self.merge_base(one, two)
666 }
667
668 fn find_commit(&self, id: git2::Oid) -> Option<std::rc::Rc<Commit>> {
669 self.find_commit(id)
670 }
671
672 fn head_commit(&self) -> std::rc::Rc<Commit> {
673 self.head_commit()
674 }
675
676 fn head_branch(&self) -> Option<Branch> {
677 self.head_branch()
678 }
679
680 fn resolve(&self, revspec: &str) -> Option<std::rc::Rc<Commit>> {
681 self.resolve(revspec)
682 }
683
684 fn parent_ids(&self, head_id: git2::Oid) -> Result<Vec<git2::Oid>, git2::Error> {
685 self.parent_ids(head_id)
686 }
687
688 fn commit_count(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize> {
689 self.commit_count(base_id, head_id)
690 }
691
692 fn commit_range(
693 &self,
694 base_bound: std::ops::Bound<&git2::Oid>,
695 head_bound: std::ops::Bound<&git2::Oid>,
696 ) -> Result<Vec<git2::Oid>, git2::Error> {
697 self.commit_range(base_bound, head_bound)
698 }
699
700 fn contains_commit(
701 &self,
702 haystack_id: git2::Oid,
703 needle_id: git2::Oid,
704 ) -> Result<bool, git2::Error> {
705 self.contains_commit(haystack_id, needle_id)
706 }
707
708 fn cherry_pick(
709 &mut self,
710 head_id: git2::Oid,
711 cherry_id: git2::Oid,
712 ) -> Result<git2::Oid, git2::Error> {
713 self.cherry_pick(head_id, cherry_id)
714 }
715
716 fn squash(&mut self, head_id: git2::Oid, into_id: git2::Oid) -> Result<git2::Oid, git2::Error> {
717 self.squash(head_id, into_id)
718 }
719
720 fn stash_push(&mut self, message: Option<&str>) -> Result<git2::Oid, git2::Error> {
721 self.stash_push(message)
722 }
723
724 fn stash_pop(&mut self, stash_id: git2::Oid) -> Result<(), git2::Error> {
725 self.stash_pop(stash_id)
726 }
727
728 fn branch(&mut self, name: &str, id: git2::Oid) -> Result<(), git2::Error> {
729 self.branch(name, id)
730 }
731
732 fn delete_branch(&mut self, name: &str) -> Result<(), git2::Error> {
733 self.delete_branch(name)
734 }
735
736 fn find_local_branch(&self, name: &str) -> Option<Branch> {
737 self.find_local_branch(name)
738 }
739
740 fn find_remote_branch(&self, remote: &str, name: &str) -> Option<Branch> {
741 self.find_remote_branch(remote, name)
742 }
743
744 fn local_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_> {
745 Box::new(self.local_branches())
746 }
747
748 fn remote_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_> {
749 Box::new(self.remote_branches())
750 }
751
752 fn detach(&mut self) -> Result<(), git2::Error> {
753 self.detach()
754 }
755
756 fn switch(&mut self, name: &str) -> Result<(), git2::Error> {
757 self.switch(name)
758 }
759}
760
/// In-memory [`Repo`] whose commits each have at most one parent.
#[derive(Debug)]
pub struct InMemoryRepo {
    /// Commits by id, each paired with its optional parent id.
    commits: std::collections::HashMap<git2::Oid, (Option<git2::Oid>, std::rc::Rc<Commit>)>,
    /// Local branches by name.
    branches: std::collections::HashMap<String, Branch>,
    /// Commit `HEAD` points at; `None` until a commit is pushed or `HEAD` is set.
    head_id: Option<git2::Oid>,

    /// Counter used by `gen_id` to mint sequential fake commit ids (starts at 1).
    last_id: std::sync::atomic::AtomicUsize,
}
769
770impl InMemoryRepo {
771 pub fn new() -> Self {
772 Self {
773 commits: Default::default(),
774 branches: Default::default(),
775 head_id: Default::default(),
776 last_id: std::sync::atomic::AtomicUsize::new(1),
777 }
778 }
779
780 pub fn clear(&mut self) {
781 *self = InMemoryRepo::new();
782 }
783
784 pub fn gen_id(&mut self) -> git2::Oid {
785 let last_id = self
786 .last_id
787 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
788 let sha = format!("{last_id:040x}");
789 git2::Oid::from_str(&sha).unwrap()
790 }
791
792 pub fn push_commit(&mut self, parent_id: Option<git2::Oid>, commit: Commit) {
793 if let Some(parent_id) = parent_id {
794 assert!(self.commits.contains_key(&parent_id));
795 }
796 self.head_id = Some(commit.id);
797 self.commits
798 .insert(commit.id, (parent_id, std::rc::Rc::new(commit)));
799 }
800
801 pub fn head_id(&mut self) -> Option<git2::Oid> {
802 self.head_id
803 }
804
805 pub fn set_head(&mut self, head_id: git2::Oid) {
806 assert!(self.commits.contains_key(&head_id));
807 self.head_id = Some(head_id);
808 }
809
810 pub fn mark_branch(&mut self, branch: Branch) {
811 assert!(self.commits.contains_key(&branch.id));
812 self.branches.insert(branch.name.clone(), branch);
813 }
814
815 fn user(&self) -> Option<std::rc::Rc<str>> {
816 None
817 }
818
819 pub fn is_dirty(&self) -> bool {
820 false
821 }
822
823 pub fn merge_base(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid> {
824 let one_ancestors: Vec<_> = self.commits_from(one).collect();
825 self.commits_from(two)
826 .filter(|two_ancestor| one_ancestors.contains(two_ancestor))
827 .map(|c| c.id)
828 .next()
829 }
830
831 pub fn find_commit(&self, id: git2::Oid) -> Option<std::rc::Rc<Commit>> {
832 self.commits.get(&id).map(|c| c.1.clone())
833 }
834
835 pub fn head_commit(&self) -> std::rc::Rc<Commit> {
836 self.commits.get(&self.head_id.unwrap()).cloned().unwrap().1
837 }
838
839 pub fn head_branch(&self) -> Option<Branch> {
840 self.branches
841 .values()
842 .find(|b| b.id == self.head_id.unwrap())
843 .cloned()
844 }
845
846 pub fn resolve(&self, revspec: &str) -> Option<std::rc::Rc<Commit>> {
847 let branch = self.branches.get(revspec)?;
848 self.find_commit(branch.id)
849 }
850
851 pub fn parent_ids(&self, head_id: git2::Oid) -> Result<Vec<git2::Oid>, git2::Error> {
852 let next = self
853 .commits
854 .get(&head_id)
855 .and_then(|(parent, _commit)| *parent);
856 Ok(next.into_iter().collect())
857 }
858
859 fn commits_from(&self, head_id: git2::Oid) -> impl Iterator<Item = std::rc::Rc<Commit>> + '_ {
860 let next = self.commits.get(&head_id).cloned();
861 CommitsFrom {
862 commits: &self.commits,
863 next,
864 }
865 }
866
867 pub fn commit_count(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize> {
868 let merge_base_id = self.merge_base(base_id, head_id)?;
869 let count = self
870 .commits_from(head_id)
871 .take_while(move |cur_id| cur_id.id != merge_base_id)
872 .count();
873 Some(count)
874 }
875
876 pub fn commit_range(
877 &self,
878 base_bound: std::ops::Bound<&git2::Oid>,
879 head_bound: std::ops::Bound<&git2::Oid>,
880 ) -> Result<Vec<git2::Oid>, git2::Error> {
881 let head_id = match head_bound {
882 std::ops::Bound::Included(head_id) | std::ops::Bound::Excluded(head_id) => *head_id,
883 std::ops::Bound::Unbounded => panic!("commit_range's HEAD cannot be unbounded"),
884 };
885 let skip = if matches!(head_bound, std::ops::Bound::Included(_)) {
886 0
887 } else {
888 1
889 };
890
891 let base_id = match base_bound {
892 std::ops::Bound::Included(base_id) | std::ops::Bound::Excluded(base_id) => {
893 debug_assert_eq!(self.merge_base(*base_id, head_id), Some(*base_id));
894 Some(*base_id)
895 }
896 std::ops::Bound::Unbounded => None,
897 };
898
899 let mut result = self
900 .commits_from(head_id)
901 .skip(skip)
902 .map(|commit| commit.id)
903 .take_while(|id| Some(*id) != base_id)
904 .collect::<Vec<_>>();
905 if let std::ops::Bound::Included(base_id) = base_bound {
906 result.push(*base_id);
907 }
908 Ok(result)
909 }
910
911 pub fn contains_commit(
912 &self,
913 haystack_id: git2::Oid,
914 needle_id: git2::Oid,
915 ) -> Result<bool, git2::Error> {
916 let mut next = Some(haystack_id);
918 while let Some(current) = next {
919 if current == needle_id {
920 return Ok(true);
921 }
922 next = self.commits.get(¤t).and_then(|c| c.0);
923 }
924 Ok(false)
925 }
926
927 pub fn cherry_pick(
928 &mut self,
929 head_id: git2::Oid,
930 cherry_id: git2::Oid,
931 ) -> Result<git2::Oid, git2::Error> {
932 let cherry_commit = self.find_commit(cherry_id).ok_or_else(|| {
933 git2::Error::new(
934 git2::ErrorCode::NotFound,
935 git2::ErrorClass::Reference,
936 format!("could not find commit {cherry_id:?}"),
937 )
938 })?;
939 let mut cherry_commit = Commit::clone(&cherry_commit);
940 let new_id = self.gen_id();
941 cherry_commit.id = new_id;
942 self.commits
943 .insert(new_id, (Some(head_id), std::rc::Rc::new(cherry_commit)));
944 Ok(new_id)
945 }
946
947 pub fn squash(
948 &mut self,
949 head_id: git2::Oid,
950 into_id: git2::Oid,
951 ) -> Result<git2::Oid, git2::Error> {
952 self.commits.get(&head_id).cloned().ok_or_else(|| {
953 git2::Error::new(
954 git2::ErrorCode::NotFound,
955 git2::ErrorClass::Reference,
956 format!("could not find commit {head_id:?}"),
957 )
958 })?;
959 let (intos_parent, into_commit) = self.commits.get(&into_id).cloned().ok_or_else(|| {
960 git2::Error::new(
961 git2::ErrorCode::NotFound,
962 git2::ErrorClass::Reference,
963 format!("could not find commit {into_id:?}"),
964 )
965 })?;
966 let intos_parent = intos_parent.unwrap();
967
968 let mut squashed_commit = Commit::clone(&into_commit);
969 let new_id = self.gen_id();
970 squashed_commit.id = new_id;
971 self.commits.insert(
972 new_id,
973 (Some(intos_parent), std::rc::Rc::new(squashed_commit)),
974 );
975 Ok(new_id)
976 }
977
978 pub fn stash_push(&mut self, _message: Option<&str>) -> Result<git2::Oid, git2::Error> {
979 Err(git2::Error::new(
980 git2::ErrorCode::NotFound,
981 git2::ErrorClass::Reference,
982 "stash is unsupported",
983 ))
984 }
985
986 pub fn stash_pop(&mut self, _stash_id: git2::Oid) -> Result<(), git2::Error> {
987 Err(git2::Error::new(
988 git2::ErrorCode::NotFound,
989 git2::ErrorClass::Reference,
990 "stash is unsupported",
991 ))
992 }
993
994 pub fn branch(&mut self, name: &str, id: git2::Oid) -> Result<(), git2::Error> {
995 self.branches.insert(
996 name.to_owned(),
997 Branch {
998 remote: None,
999 name: name.to_owned(),
1000 id,
1001 push_id: None,
1002 pull_id: None,
1003 },
1004 );
1005 Ok(())
1006 }
1007
1008 pub fn delete_branch(&mut self, name: &str) -> Result<(), git2::Error> {
1009 self.branches.remove(name).map(|_| ()).ok_or_else(|| {
1010 git2::Error::new(
1011 git2::ErrorCode::NotFound,
1012 git2::ErrorClass::Reference,
1013 format!("could not remove branch {name:?}"),
1014 )
1015 })
1016 }
1017
1018 pub fn find_local_branch(&self, name: &str) -> Option<Branch> {
1019 self.branches.get(name).cloned()
1020 }
1021
1022 pub fn find_remote_branch(&self, _remote: &str, _name: &str) -> Option<Branch> {
1023 None
1024 }
1025
1026 pub fn local_branches(&self) -> impl Iterator<Item = Branch> + '_ {
1027 self.branches.values().cloned()
1028 }
1029
1030 pub fn remote_branches(&self) -> impl Iterator<Item = Branch> + '_ {
1031 None.into_iter()
1032 }
1033
1034 pub fn detach(&mut self) -> Result<(), git2::Error> {
1035 Ok(())
1036 }
1037
1038 pub fn switch(&mut self, name: &str) -> Result<(), git2::Error> {
1039 let branch = self.find_local_branch(name).ok_or_else(|| {
1040 git2::Error::new(
1041 git2::ErrorCode::NotFound,
1042 git2::ErrorClass::Reference,
1043 format!("could not find branch {name:?}"),
1044 )
1045 })?;
1046 self.head_id = Some(branch.id);
1047 Ok(())
1048 }
1049}
1050
1051impl Default for InMemoryRepo {
1052 fn default() -> Self {
1053 Self::new()
1054 }
1055}
1056
/// Iterator that follows parent links through an [`InMemoryRepo`] commit map.
struct CommitsFrom<'c> {
    /// Commit map being walked.
    commits: &'c std::collections::HashMap<git2::Oid, (Option<git2::Oid>, std::rc::Rc<Commit>)>,
    /// Entry to yield next; `None` once a root or an unknown parent is reached.
    next: Option<(Option<git2::Oid>, std::rc::Rc<Commit>)>,
}
1061
1062impl Iterator for CommitsFrom<'_> {
1063 type Item = std::rc::Rc<Commit>;
1064
1065 fn next(&mut self) -> Option<Self::Item> {
1066 let mut current = None;
1067 std::mem::swap(&mut current, &mut self.next);
1068 let current = current?;
1069 if let Some(parent_id) = current.0 {
1070 self.next = self.commits.get(&parent_id).cloned();
1071 }
1072 Some(current.1)
1073 }
1074}
1075
1076impl Repo for InMemoryRepo {
1077 fn path(&self) -> Option<&std::path::Path> {
1078 None
1079 }
1080 fn user(&self) -> Option<std::rc::Rc<str>> {
1081 self.user()
1082 }
1083
1084 fn is_dirty(&self) -> bool {
1085 self.is_dirty()
1086 }
1087
1088 fn merge_base(&self, one: git2::Oid, two: git2::Oid) -> Option<git2::Oid> {
1089 self.merge_base(one, two)
1090 }
1091
1092 fn find_commit(&self, id: git2::Oid) -> Option<std::rc::Rc<Commit>> {
1093 self.find_commit(id)
1094 }
1095
1096 fn head_commit(&self) -> std::rc::Rc<Commit> {
1097 self.head_commit()
1098 }
1099
1100 fn resolve(&self, revspec: &str) -> Option<std::rc::Rc<Commit>> {
1101 self.resolve(revspec)
1102 }
1103
1104 fn parent_ids(&self, head_id: git2::Oid) -> Result<Vec<git2::Oid>, git2::Error> {
1105 self.parent_ids(head_id)
1106 }
1107
1108 fn commit_count(&self, base_id: git2::Oid, head_id: git2::Oid) -> Option<usize> {
1109 self.commit_count(base_id, head_id)
1110 }
1111
1112 fn commit_range(
1113 &self,
1114 base_bound: std::ops::Bound<&git2::Oid>,
1115 head_bound: std::ops::Bound<&git2::Oid>,
1116 ) -> Result<Vec<git2::Oid>, git2::Error> {
1117 self.commit_range(base_bound, head_bound)
1118 }
1119
1120 fn contains_commit(
1121 &self,
1122 haystack_id: git2::Oid,
1123 needle_id: git2::Oid,
1124 ) -> Result<bool, git2::Error> {
1125 self.contains_commit(haystack_id, needle_id)
1126 }
1127
1128 fn cherry_pick(
1129 &mut self,
1130 head_id: git2::Oid,
1131 cherry_id: git2::Oid,
1132 ) -> Result<git2::Oid, git2::Error> {
1133 self.cherry_pick(head_id, cherry_id)
1134 }
1135
1136 fn squash(&mut self, head_id: git2::Oid, into_id: git2::Oid) -> Result<git2::Oid, git2::Error> {
1137 self.squash(head_id, into_id)
1138 }
1139
1140 fn head_branch(&self) -> Option<Branch> {
1141 self.head_branch()
1142 }
1143
1144 fn stash_push(&mut self, message: Option<&str>) -> Result<git2::Oid, git2::Error> {
1145 self.stash_push(message)
1146 }
1147
1148 fn stash_pop(&mut self, stash_id: git2::Oid) -> Result<(), git2::Error> {
1149 self.stash_pop(stash_id)
1150 }
1151
1152 fn branch(&mut self, name: &str, id: git2::Oid) -> Result<(), git2::Error> {
1153 self.branch(name, id)
1154 }
1155
1156 fn delete_branch(&mut self, name: &str) -> Result<(), git2::Error> {
1157 self.delete_branch(name)
1158 }
1159
1160 fn find_local_branch(&self, name: &str) -> Option<Branch> {
1161 self.find_local_branch(name)
1162 }
1163
1164 fn find_remote_branch(&self, remote: &str, name: &str) -> Option<Branch> {
1165 self.find_remote_branch(remote, name)
1166 }
1167
1168 fn local_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_> {
1169 Box::new(self.local_branches())
1170 }
1171
1172 fn remote_branches(&self) -> Box<dyn Iterator<Item = Branch> + '_> {
1173 Box::new(self.remote_branches())
1174 }
1175
1176 fn detach(&mut self) -> Result<(), git2::Error> {
1177 self.detach()
1178 }
1179
1180 fn switch(&mut self, name: &str) -> Result<(), git2::Error> {
1181 self.switch(name)
1182 }
1183}
1184
1185pub fn stash_push(repo: &mut dyn Repo, context: &str) -> Option<git2::Oid> {
1186 let branch = repo.head_branch();
1187 let stash_msg = format!(
1188 "WIP on {} ({})",
1189 branch.as_ref().map(|b| b.name.as_str()).unwrap_or("HEAD"),
1190 context
1191 );
1192 match repo.stash_push(Some(&stash_msg)) {
1193 Ok(stash_id) => {
1194 log::info!(
1195 "Saved working directory and index state {}: {}",
1196 stash_msg,
1197 stash_id
1198 );
1199 Some(stash_id)
1200 }
1201 Err(err) => {
1202 log::debug!("Failed to stash: {}", err);
1203 None
1204 }
1205 }
1206}
1207
1208pub fn stash_pop(repo: &mut dyn Repo, stash_id: Option<git2::Oid>) {
1209 if let Some(stash_id) = stash_id {
1210 match repo.stash_pop(stash_id) {
1211 Ok(()) => {
1212 log::info!("Dropped refs/stash {}", stash_id);
1213 }
1214 Err(err) => {
1215 log::error!("Failed to pop {} from stash: {}", stash_id, err);
1216 }
1217 }
1218 }
1219}
1220
1221pub fn commit_range(
1222 repo: &dyn Repo,
1223 head_to_base: impl std::ops::RangeBounds<git2::Oid>,
1224) -> Result<Vec<git2::Oid>, git2::Error> {
1225 let head_bound = head_to_base.start_bound();
1226 let base_bound = head_to_base.end_bound();
1227 repo.commit_range(base_bound, head_bound)
1228}