1use std::{fmt, path};
2
3use anyhow::*;
4use git2::StatusOptions;
5use git2::build::CheckoutBuilder;
6use std::result::Result::Ok;
7
/// Outcome of [`Repo::merge`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MergeResult {
    /// Nothing to integrate (for example, no upstream is configured or HEAD already matches it).
    UpToDate,
    /// The local branch was moved forward to the upstream tip.
    FastForward,
    /// A merge commit joining the local and upstream tips was created.
    Merged,
    /// Local commits were replayed on top of the upstream tip.
    Rebased,
    /// Integration stopped because of conflicts.
    Conflicts,
}
16
/// Relationship between HEAD and its upstream, as reported by [`Repo::get_remote_comparison`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RemoteComparison {
    /// HEAD and the upstream tip are the same commit (or neither has unique commits).
    UpToDate,
    /// HEAD has this many commits that the upstream does not.
    Ahead(usize),
    /// The upstream has this many commits that HEAD does not.
    Behind(usize),
    /// Both sides have unique commits: `(ahead, behind)`.
    Diverged(usize, usize),
    /// A tracking branch is configured but its upstream ref does not exist locally.
    NoRemote,
}
25
26pub struct Repo {
27 pub git_repo: git2::Repository,
28 pub work_dir: path::PathBuf,
29 pub head: String,
30 pub subrepos: Vec<Repo>,
31}
32
33impl Repo {
34 pub fn new(work_dir: &path::Path, head_name: Option<&str>) -> Result<Self> {
35 let git_repo = git2::Repository::open(work_dir)
36 .with_context(|| format!("Cannot open repo at `{}`", work_dir.display()))?;
37
38 let head = match head_name {
39 Some(name) => String::from(name),
40 None => {
41 let is_detached = git_repo.head_detached().with_context(|| {
42 format!(
43 "Cannot determine head state for repo at `{}`",
44 work_dir.display()
45 )
46 })?;
47 if is_detached {
48 String::from("<detached>")
49 } else {
50 String::from(git_repo.head().with_context(|| {
51 format!(
52 "Cannot find the head branch for repo at `{}`. Is it detached?",
53 work_dir.display()
54 )
55 })?.shorthand().with_context(|| {
56 format!(
57 "Cannot find a human readable representation of the head ref for repo at `{}`",
58 work_dir.display(),
59 )
60 })?)
61 }
62 },
63 };
64
65 let subrepos = git_repo
66 .submodules()
67 .with_context(|| {
68 format!(
69 "Cannot load submodules for repo at `{}`",
70 work_dir.display()
71 )
72 })?
73 .iter()
74 .map(|submodule| Repo::new(&work_dir.join(submodule.path()), None))
75 .collect::<Result<Vec<Repo>>>()?;
76
77 Ok(Repo {
78 git_repo,
79 work_dir: path::PathBuf::from(work_dir),
80 head,
81 subrepos,
82 })
83 }
84
85 pub fn get_subrepo_by_path(&self, subrepo_path: &path::PathBuf) -> Option<&Repo> {
86 self.subrepos
87 .iter()
88 .find(|subrepo| subrepo.work_dir == self.work_dir.join(subrepo_path))
89 }
90
91 pub fn sync(&self) -> Result<()> {
92 self.switch(&self.head)?;
93 Ok(())
94 }
95
96 pub fn switch(&self, head: &str) -> Result<()> {
97 self.git_repo.set_head(&self.resolve_reference(head)?)?;
98 let checkout_result = self.git_repo.checkout_head(None);
99 checkout_result?;
100 Ok(())
101 }
102
103 pub fn switch_force(&self, head: &str) -> Result<()> {
104 self.git_repo.set_head(&self.resolve_reference(head)?)?;
105 let mut checkout = CheckoutBuilder::new();
106 checkout.force();
107 let checkout_result = self.git_repo.checkout_head(Some(&mut checkout));
108 checkout_result?;
109 Ok(())
110 }
111
112 pub fn refresh_worktree(&self) -> Result<()> {
113 let checkout_result = self.git_repo.checkout_head(None);
114 checkout_result?;
115 Ok(())
116 }
117
118 pub fn refresh_worktree_force(&self) -> Result<()> {
119 let mut checkout = CheckoutBuilder::new();
120 checkout.force();
121 let checkout_result = self.git_repo.checkout_head(Some(&mut checkout));
122 checkout_result?;
123 Ok(())
124 }
125
126 pub fn checkout_path_from_head(&self, path: &path::Path) -> Result<()> {
127 let mut checkout = CheckoutBuilder::new();
128 checkout.force().path(path);
129 self.git_repo.checkout_head(Some(&mut checkout))?;
130 Ok(())
131 }
132
133 fn switch_forced(&self, head: &str) -> Result<()> {
134 self.git_repo.set_head(&self.resolve_reference(head)?)?;
135 let mut checkout = CheckoutBuilder::new();
136 checkout.force();
137 self.git_repo.checkout_head(Some(&mut checkout))?;
138 Ok(())
139 }
140
141 pub fn fetch(&self) -> Result<()> {
142 if self.git_repo.head_detached().with_context(|| {
143 format!(
144 "Cannot determine head state for repo at `{}`",
145 self.work_dir.display()
146 )
147 })? {
148 return Ok(());
149 }
150
151 let head_ref = self.git_repo.head()?;
153 let branch_name = head_ref.shorthand().with_context(|| {
154 format!(
155 "Cannot get branch name for repo at `{}`",
156 self.work_dir.display()
157 )
158 })?;
159
160 let tracking = match self.tracking_branch(branch_name)? {
161 Some(tracking) => tracking,
162 None => {
163 return Ok(());
165 },
166 };
167
168 match self.git_repo.find_remote(&tracking.remote) {
170 Ok(mut remote) => {
171 let mut fetch_options = git2::FetchOptions::new();
172 fetch_options.remote_callbacks(self.remote_callbacks()?);
173
174 remote
175 .fetch::<&str>(&[], Some(&mut fetch_options), None)
176 .with_context(|| {
177 format!(
178 "Failed to fetch from remote '{}' for repo at `{}`\n\
179 \n\
180 Possible causes:\n\
181 - SSH agent not running or not accessible (check SSH_AUTH_SOCK)\n\
182 - SSH keys not properly configured in ~/.ssh/\n\
183 - Credential helper not configured (git config credential.helper)\n\
184 - Network/firewall issues\n\
185 \n\
186 Try running: git fetch --verbose\n\
187 Or check authentication with: git-wok test-auth",
188 tracking.remote,
189 self.work_dir.display()
190 )
191 })?;
192 },
193 Err(_) => {
194 return Ok(());
196 },
197 }
198
199 Ok(())
200 }
201
202 pub fn ensure_on_branch(&self, branch_name: &str) -> Result<()> {
203 if !self.is_worktree_clean()? {
204 bail!(
205 "Refusing to switch branches with uncommitted changes in `{}`",
206 self.work_dir.display()
207 );
208 }
209
210 if !self.git_repo.head_detached().with_context(|| {
211 format!(
212 "Cannot determine head state for repo at `{}`",
213 self.work_dir.display()
214 )
215 })? && let Ok(head) = self.git_repo.head()
216 && head.shorthand() == Some(branch_name)
217 {
218 return Ok(());
219 }
220
221 let local_ref = format!("refs/heads/{}", branch_name);
222 if self.git_repo.find_reference(&local_ref).is_ok() {
223 self.switch_forced(branch_name)?;
224 return Ok(());
225 }
226
227 let remote_name = self.get_remote_name_for_branch(branch_name)?;
228 if let Ok(mut remote) = self.git_repo.find_remote(&remote_name) {
229 let mut fetch_options = git2::FetchOptions::new();
230 fetch_options.remote_callbacks(self.remote_callbacks()?);
231 remote.fetch::<&str>(&[], Some(&mut fetch_options), None)?;
232 }
233
234 let remote_ref = format!("refs/remotes/{}/{}", remote_name, branch_name);
235 if let Ok(remote_oid) = self.git_repo.refname_to_id(&remote_ref) {
236 let remote_commit = self.git_repo.find_commit(remote_oid)?;
237 self.git_repo.branch(branch_name, &remote_commit, false)?;
238 let mut local_branch = self
239 .git_repo
240 .find_branch(branch_name, git2::BranchType::Local)?;
241 local_branch
242 .set_upstream(Some(&format!("{}/{}", remote_name, branch_name)))?;
243 self.switch(branch_name)?;
244 return Ok(());
245 }
246
247 let head = self.git_repo.head()?;
248 let current_commit = head.peel_to_commit()?;
249 self.git_repo.branch(branch_name, ¤t_commit, false)?;
250 self.switch(branch_name)?;
251 Ok(())
252 }
253
254 pub fn ensure_on_branch_existing_or_remote(
255 &self,
256 branch_name: &str,
257 create: bool,
258 ) -> Result<()> {
259 if !self.is_worktree_clean()? {
260 bail!(
261 "Refusing to switch branches with uncommitted changes in `{}`",
262 self.work_dir.display()
263 );
264 }
265
266 if !self.git_repo.head_detached().with_context(|| {
267 format!(
268 "Cannot determine head state for repo at `{}`",
269 self.work_dir.display()
270 )
271 })? && let Ok(head) = self.git_repo.head()
272 && head.shorthand() == Some(branch_name)
273 {
274 return Ok(());
275 }
276
277 let local_ref = format!("refs/heads/{}", branch_name);
278 if self.git_repo.find_reference(&local_ref).is_ok() {
279 self.switch(branch_name)?;
280 return Ok(());
281 }
282
283 let remote_name = self.get_remote_name_for_branch(branch_name)?;
284 if let Ok(mut remote) = self.git_repo.find_remote(&remote_name) {
285 let mut fetch_options = git2::FetchOptions::new();
286 fetch_options.remote_callbacks(self.remote_callbacks()?);
287 remote.fetch::<&str>(&[], Some(&mut fetch_options), None)?;
288 }
289
290 let remote_ref = format!("refs/remotes/{}/{}", remote_name, branch_name);
291 if let Ok(remote_oid) = self.git_repo.refname_to_id(&remote_ref) {
292 let remote_commit = self.git_repo.find_commit(remote_oid)?;
293 self.git_repo.branch(branch_name, &remote_commit, false)?;
294 let mut local_branch = self
295 .git_repo
296 .find_branch(branch_name, git2::BranchType::Local)?;
297 local_branch
298 .set_upstream(Some(&format!("{}/{}", remote_name, branch_name)))?;
299 self.switch_forced(branch_name)?;
300 return Ok(());
301 }
302
303 if create {
304 let head = self.git_repo.head()?;
305 let current_commit = head.peel_to_commit()?;
306 self.git_repo.branch(branch_name, ¤t_commit, false)?;
307 self.switch_forced(branch_name)?;
308 return Ok(());
309 }
310
311 bail!(
312 "Branch '{}' does not exist and --create not specified",
313 branch_name
314 );
315 }
316
317 fn rebase(
318 &self,
319 _branch_name: &str,
320 remote_commit: &git2::Commit,
321 ) -> Result<MergeResult> {
322 let _local_commit = self.git_repo.head()?.peel_to_commit()?;
323 let remote_oid = remote_commit.id();
324
325 let remote_annotated = self.git_repo.find_annotated_commit(remote_oid)?;
327
328 let signature = self.git_repo.signature()?;
330 let mut rebase = self.git_repo.rebase(
331 None, Some(&remote_annotated), None, None, )?;
336
337 let mut has_conflicts = false;
339 while let Some(op) = rebase.next() {
340 match op {
341 Ok(_rebase_op) => {
342 let index = self.git_repo.index()?;
344 if index.has_conflicts() {
345 has_conflicts = true;
346 break;
347 }
348
349 if rebase.commit(None, &signature, None).is_err() {
351 has_conflicts = true;
352 break;
353 }
354 },
355 Err(_) => {
356 has_conflicts = true;
357 break;
358 },
359 }
360 }
361
362 if has_conflicts {
363 return Ok(MergeResult::Conflicts);
365 }
366
367 rebase.finish(Some(&signature))?;
369
370 Ok(MergeResult::Rebased)
371 }
372
373 pub fn merge(&self, branch_name: &str) -> Result<MergeResult> {
374 self.fetch()?;
376
377 let tracking = match self.tracking_branch(branch_name)? {
379 Some(tracking) => tracking,
380 None => {
381 return Ok(MergeResult::UpToDate);
383 },
384 };
385
386 let remote_branch_oid = match self.git_repo.refname_to_id(&tracking.remote_ref)
388 {
389 Ok(oid) => oid,
390 Err(_) => {
391 return Ok(MergeResult::UpToDate);
393 },
394 };
395
396 let remote_commit = self.git_repo.find_commit(remote_branch_oid)?;
397 let local_commit = self.git_repo.head()?.peel_to_commit()?;
398
399 if local_commit.id() == remote_commit.id() {
401 return Ok(MergeResult::UpToDate);
402 }
403
404 if self
406 .git_repo
407 .graph_descendant_of(remote_commit.id(), local_commit.id())?
408 {
409 self.git_repo.reference(
411 &format!("refs/heads/{}", branch_name),
412 remote_commit.id(),
413 true,
414 &format!("Fast-forward '{}' to {}", branch_name, tracking.remote_ref),
415 )?;
416 self.git_repo
417 .set_head(&format!("refs/heads/{}", branch_name))?;
418 let mut checkout = CheckoutBuilder::new();
419 checkout.force();
420 self.git_repo.checkout_head(Some(&mut checkout))?;
421 return Ok(MergeResult::FastForward);
422 }
423
424 let pull_strategy = self.get_pull_strategy(branch_name)?;
426
427 match pull_strategy {
428 PullStrategy::Rebase => {
429 self.rebase(branch_name, &remote_commit)
431 },
432 PullStrategy::Merge => {
433 self.do_merge(branch_name, &local_commit, &remote_commit, &tracking)
435 },
436 }
437 }
438
439 fn do_merge(
440 &self,
441 branch_name: &str,
442 local_commit: &git2::Commit,
443 remote_commit: &git2::Commit,
444 tracking: &TrackingBranch,
445 ) -> Result<MergeResult> {
446 let mut merge_opts = git2::MergeOptions::new();
448 merge_opts.fail_on_conflict(false); let _merge_result = self.git_repo.merge_commits(
451 local_commit,
452 remote_commit,
453 Some(&merge_opts),
454 )?;
455
456 let mut index = self.git_repo.index()?;
458 let has_conflicts = index.has_conflicts();
459
460 if !has_conflicts {
461 let signature = self.git_repo.signature()?;
463 let tree_id = index.write_tree()?;
464 let tree = self.git_repo.find_tree(tree_id)?;
465
466 self.git_repo.commit(
467 Some(&format!("refs/heads/{}", branch_name)),
468 &signature,
469 &signature,
470 &format!("Merge remote-tracking branch '{}'", tracking.remote_ref),
471 &tree,
472 &[local_commit, remote_commit],
473 )?;
474
475 self.git_repo.cleanup_state()?;
476
477 Ok(MergeResult::Merged)
478 } else {
479 Ok(MergeResult::Conflicts)
481 }
482 }
483
484 pub fn get_remote_name_for_branch(&self, branch_name: &str) -> Result<String> {
485 if let Some(tracking) = self.tracking_branch(branch_name)? {
486 Ok(tracking.remote)
487 } else {
488 Ok("origin".to_string())
490 }
491 }
492
493 pub fn get_remote_comparison(
495 &self,
496 branch_name: &str,
497 ) -> Result<Option<RemoteComparison>> {
498 let tracking = match self.tracking_branch(branch_name)? {
500 Some(tracking) => tracking,
501 None => return Ok(None), };
503
504 let remote_oid = match self.git_repo.refname_to_id(&tracking.remote_ref) {
506 Ok(oid) => oid,
507 Err(_) => {
508 return Ok(Some(RemoteComparison::NoRemote));
510 },
511 };
512
513 let local_oid = self.git_repo.head()?.peel_to_commit()?.id();
515
516 if local_oid == remote_oid {
518 return Ok(Some(RemoteComparison::UpToDate));
519 }
520
521 let (ahead, behind) =
523 self.git_repo.graph_ahead_behind(local_oid, remote_oid)?;
524
525 if ahead > 0 && behind > 0 {
526 Ok(Some(RemoteComparison::Diverged(ahead, behind)))
527 } else if ahead > 0 {
528 Ok(Some(RemoteComparison::Ahead(ahead)))
529 } else if behind > 0 {
530 Ok(Some(RemoteComparison::Behind(behind)))
531 } else {
532 Ok(Some(RemoteComparison::UpToDate))
533 }
534 }
535
536 pub fn remote_callbacks(&self) -> Result<git2::RemoteCallbacks<'static>> {
537 self.remote_callbacks_impl(false)
538 }
539
540 pub fn remote_callbacks_verbose(&self) -> Result<git2::RemoteCallbacks<'static>> {
541 self.remote_callbacks_impl(true)
542 }
543
544 fn remote_callbacks_impl(
545 &self,
546 verbose: bool,
547 ) -> Result<git2::RemoteCallbacks<'static>> {
548 let config = self.git_repo.config()?;
549
550 let mut callbacks = git2::RemoteCallbacks::new();
551 callbacks.credentials(move |url, username_from_url, allowed| {
552 if verbose {
553 eprintln!("DEBUG: Credential callback invoked");
554 eprintln!(" URL: {}", url);
555 eprintln!(" Username from URL: {:?}", username_from_url);
556 eprintln!(" Allowed types: {:?}", allowed);
557 }
558
559 if allowed.contains(git2::CredentialType::SSH_KEY) {
561 if let Some(username) = username_from_url {
562 if std::env::var("SSH_AUTH_SOCK").is_ok() {
564 if verbose {
565 eprintln!(
566 " Attempting: SSH key from agent for user '{}'",
567 username
568 );
569 }
570 match git2::Cred::ssh_key_from_agent(username) {
571 Ok(cred) => {
572 if verbose {
573 eprintln!(" SUCCESS: SSH key from agent");
574 }
575 return Ok(cred);
576 },
577 Err(e) => {
578 if verbose {
579 eprintln!(" FAILED: SSH key from agent - {}", e);
580 }
581 },
582 }
583 } else if verbose {
584 eprintln!(
585 " SKIPPED: SSH key from agent (SSH_AUTH_SOCK not set)"
586 );
587 }
588 } else if verbose {
589 eprintln!(" SKIPPED: SSH key from agent (no username provided)");
590 }
591
592 if let Some(username) = username_from_url
594 && let Ok(home) = std::env::var("HOME")
595 {
596 let key_paths = vec![
597 format!("{}/.ssh/id_ed25519", home),
598 format!("{}/.ssh/id_rsa", home),
599 format!("{}/.ssh/id_ecdsa", home),
600 ];
601
602 for key_path in key_paths {
603 if path::Path::new(&key_path).exists() {
604 if verbose {
605 eprintln!(" Attempting: SSH key file at {}", key_path);
606 }
607 match git2::Cred::ssh_key(
608 username,
609 None, path::Path::new(&key_path),
611 None, ) {
613 Ok(cred) => {
614 if verbose {
615 eprintln!(" SUCCESS: SSH key file");
616 }
617 return Ok(cred);
618 },
619 Err(e) => {
620 if verbose {
621 eprintln!(" FAILED: SSH key file - {}", e);
622 }
623 },
624 }
625 }
626 }
627 }
628 }
629
630 if allowed.contains(git2::CredentialType::USER_PASS_PLAINTEXT)
632 || allowed.contains(git2::CredentialType::SSH_KEY)
633 || allowed.contains(git2::CredentialType::DEFAULT)
634 {
635 if verbose {
636 eprintln!(" Attempting: Credential helper");
637 }
638 match git2::Cred::credential_helper(&config, url, username_from_url) {
639 Ok(cred) => {
640 if verbose {
641 eprintln!(" SUCCESS: Credential helper");
642 }
643 return Ok(cred);
644 },
645 Err(e) => {
646 if verbose {
647 eprintln!(" FAILED: Credential helper - {}", e);
648 }
649 },
650 }
651 }
652
653 if allowed.contains(git2::CredentialType::USERNAME) {
655 let username = username_from_url.unwrap_or("git");
656 if verbose {
657 eprintln!(" Attempting: Username only ('{}')", username);
658 }
659 match git2::Cred::username(username) {
660 Ok(cred) => {
661 if verbose {
662 eprintln!(" SUCCESS: Username");
663 }
664 return Ok(cred);
665 },
666 Err(e) => {
667 if verbose {
668 eprintln!(" FAILED: Username - {}", e);
669 }
670 },
671 }
672 }
673
674 if verbose {
676 eprintln!(" Attempting: Default credentials");
677 }
678 match git2::Cred::default() {
679 Ok(cred) => {
680 if verbose {
681 eprintln!(" SUCCESS: Default credentials");
682 }
683 Ok(cred)
684 },
685 Err(e) => {
686 if verbose {
687 eprintln!(" FAILED: All credential methods exhausted");
688 eprintln!(" Last error: {}", e);
689 }
690 Err(e)
691 },
692 }
693 });
694
695 Ok(callbacks)
696 }
697
698 fn resolve_reference(&self, short_name: &str) -> Result<String> {
699 Ok(self
700 .git_repo
701 .resolve_reference_from_short_name(short_name)?
702 .name()
703 .with_context(|| {
704 format!(
705 "Cannot resolve head reference for repo at `{}`",
706 self.work_dir.display()
707 )
708 })?
709 .to_owned())
710 }
711
712 pub fn tracking_branch(&self, branch_name: &str) -> Result<Option<TrackingBranch>> {
713 let config = self.git_repo.config()?;
714
715 let remote_key = format!("branch.{}.remote", branch_name);
716 let merge_key = format!("branch.{}.merge", branch_name);
717
718 let remote = match config.get_string(&remote_key) {
719 Ok(name) => name,
720 Err(err) if err.code() == git2::ErrorCode::NotFound => return Ok(None),
721 Err(err) => return Err(err.into()),
722 };
723
724 let merge_ref = match config.get_string(&merge_key) {
725 Ok(name) => name,
726 Err(err) if err.code() == git2::ErrorCode::NotFound => return Ok(None),
727 Err(err) => return Err(err.into()),
728 };
729
730 let branch_short = merge_ref
731 .strip_prefix("refs/heads/")
732 .unwrap_or(&merge_ref)
733 .to_owned();
734
735 let remote_ref = format!("refs/remotes/{}/{}", remote, branch_short);
736
737 Ok(Some(TrackingBranch { remote, remote_ref }))
738 }
739
740 fn get_pull_strategy(&self, branch_name: &str) -> Result<PullStrategy> {
741 let config = self.git_repo.config()?;
742
743 let branch_rebase_key = format!("branch.{}.rebase", branch_name);
745 if let Ok(value) = config.get_string(&branch_rebase_key) {
746 return Ok(parse_rebase_config(&value));
747 }
748
749 if let Ok(value) = config.get_string("pull.rebase") {
751 return Ok(parse_rebase_config(&value));
752 }
753
754 if let Ok(value) = config.get_bool("pull.rebase") {
756 return Ok(if value {
757 PullStrategy::Rebase
758 } else {
759 PullStrategy::Merge
760 });
761 }
762
763 Ok(PullStrategy::Merge)
765 }
766
767 fn is_worktree_clean(&self) -> Result<bool> {
768 let mut status_options = StatusOptions::new();
769 status_options.include_ignored(false);
770 status_options.include_untracked(true);
771 let statuses = self.git_repo.statuses(Some(&mut status_options))?;
772 Ok(statuses.is_empty())
773 }
774}
775
776impl fmt::Debug for Repo {
777 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
778 f.debug_struct("Repo")
779 .field("work_dir", &self.work_dir)
780 .field("head", &self.head)
781 .field("subrepos", &self.subrepos)
782 .finish()
783 }
784}
785
/// Upstream configuration of a local branch, derived from `branch.<name>.remote` and
/// `branch.<name>.merge`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TrackingBranch {
    /// Name of the configured remote (for example `origin`).
    pub remote: String,
    /// Full name of the ref expected to hold the upstream tip
    /// (for example `refs/remotes/origin/main`).
    pub remote_ref: String,
}
790
/// How diverged upstream commits are integrated when pulling.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PullStrategy {
    /// Create a merge commit.
    Merge,
    /// Replay local commits on top of the upstream tip.
    Rebase,
}

/// Interprets a `branch.<name>.rebase` / `pull.rebase` config value.
///
/// Matching is case-insensitive and ignores surrounding whitespace. The following select
/// [`PullStrategy::Rebase`]:
/// - git's true spellings: `true`, `yes`, `on`, or any non-zero integer;
/// - the rebase modes `interactive`/`i` and `merges`/`m`.
///
/// Everything else, including git's false spellings (`false`, `no`, `off`, `0`), selects
/// [`PullStrategy::Merge`].
fn parse_rebase_config(value: &str) -> PullStrategy {
    let value = value.trim().to_ascii_lowercase();
    let rebase = match value.as_str() {
        "true" | "yes" | "on" | "interactive" | "i" | "merges" | "m" => true,
        // git also accepts integers as booleans, where any non-zero value means true.
        other => other.parse::<i64>().is_ok_and(|n| n != 0),
    };
    if rebase {
        PullStrategy::Rebase
    } else {
        PullStrategy::Merge
    }
}