1use std::{fmt, path};
2
3use anyhow::*;
4use git2::StatusOptions;
5use git2::build::CheckoutBuilder;
6use std::result::Result::Ok;
7
/// Outcome of integrating a branch's upstream into it (see `Repo::merge`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MergeResult {
    /// Nothing was integrated (e.g. no upstream is configured, the upstream
    /// ref is missing, or the branch already matches it).
    UpToDate,
    /// The local branch was moved forward to the upstream commit.
    FastForward,
    /// A merge commit combining local and upstream history was created.
    Merged,
    /// Local commits were replayed on top of the upstream commit.
    Rebased,
    /// Integration stopped on conflicts that need manual resolution.
    Conflicts,
}
16
/// Relationship between a local branch's commit and its upstream commit.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RemoteComparison {
    /// Local and upstream point at the same commit.
    UpToDate,
    /// Local has this many commits that upstream does not.
    Ahead(usize),
    /// Upstream has this many commits that local does not.
    Behind(usize),
    /// Both sides have unique commits: `(ahead, behind)`.
    Diverged(usize, usize),
    /// An upstream is configured but its remote-tracking ref does not exist.
    NoRemote,
}
25
/// A git repository opened through libgit2, together with the repositories of
/// its submodules (opened recursively by `Repo::new`).
pub struct Repo {
    /// Handle to the underlying libgit2 repository.
    pub git_repo: git2::Repository,
    /// Directory the repository was opened from.
    pub work_dir: path::PathBuf,
    /// Head name associated with this repo: the name passed to `Repo::new`, or,
    /// when none was given, the checked-out branch's short name or
    /// `"<detached>"`.
    pub head: String,
    /// One `Repo` per submodule of this repository.
    pub subrepos: Vec<Repo>,
}
32
33impl Repo {
34 pub fn new(work_dir: &path::Path, head_name: Option<&str>) -> Result<Self> {
35 let git_repo = git2::Repository::open(work_dir)
36 .with_context(|| format!("Cannot open repo at `{}`", work_dir.display()))?;
37
38 let head = match head_name {
39 Some(name) => String::from(name),
40 None => {
41 let is_detached = git_repo.head_detached().with_context(|| {
42 format!(
43 "Cannot determine head state for repo at `{}`",
44 work_dir.display()
45 )
46 })?;
47 if is_detached {
48 String::from("<detached>")
49 } else {
50 String::from(git_repo.head().with_context(|| {
51 format!(
52 "Cannot find the head branch for repo at `{}`. Is it detached?",
53 work_dir.display()
54 )
55 })?.shorthand().with_context(|| {
56 format!(
57 "Cannot find a human readable representation of the head ref for repo at `{}`",
58 work_dir.display(),
59 )
60 })?)
61 }
62 },
63 };
64
65 let subrepos = git_repo
66 .submodules()
67 .with_context(|| {
68 format!(
69 "Cannot load submodules for repo at `{}`",
70 work_dir.display()
71 )
72 })?
73 .iter()
74 .map(|submodule| Repo::new(&work_dir.join(submodule.path()), None))
75 .collect::<Result<Vec<Repo>>>()?;
76
77 Ok(Repo {
78 git_repo,
79 work_dir: path::PathBuf::from(work_dir),
80 head,
81 subrepos,
82 })
83 }
84
85 pub fn get_subrepo_by_path(&self, subrepo_path: &path::PathBuf) -> Option<&Repo> {
86 self.subrepos
87 .iter()
88 .find(|subrepo| subrepo.work_dir == self.work_dir.join(subrepo_path))
89 }
90
91 pub fn sync(&self) -> Result<()> {
92 self.switch(&self.head)?;
93 Ok(())
94 }
95
96 pub fn switch(&self, head: &str) -> Result<()> {
97 self.git_repo.set_head(&self.resolve_reference(head)?)?;
98 self.git_repo.checkout_head(None)?;
99 Ok(())
100 }
101
102 pub fn checkout_path_from_head(&self, path: &path::Path) -> Result<()> {
103 let mut checkout = CheckoutBuilder::new();
104 checkout.force().path(path);
105 self.git_repo.checkout_head(Some(&mut checkout))?;
106 Ok(())
107 }
108
109 fn switch_forced(&self, head: &str) -> Result<()> {
110 self.git_repo.set_head(&self.resolve_reference(head)?)?;
111 let mut checkout = CheckoutBuilder::new();
112 checkout.force();
113 self.git_repo.checkout_head(Some(&mut checkout))?;
114 Ok(())
115 }
116
117 pub fn fetch(&self) -> Result<()> {
118 if self.git_repo.head_detached().with_context(|| {
119 format!(
120 "Cannot determine head state for repo at `{}`",
121 self.work_dir.display()
122 )
123 })? {
124 return Ok(());
125 }
126
127 let head_ref = self.git_repo.head()?;
129 let branch_name = head_ref.shorthand().with_context(|| {
130 format!(
131 "Cannot get branch name for repo at `{}`",
132 self.work_dir.display()
133 )
134 })?;
135
136 let tracking = match self.tracking_branch(branch_name)? {
137 Some(tracking) => tracking,
138 None => {
139 return Ok(());
141 },
142 };
143
144 match self.git_repo.find_remote(&tracking.remote) {
146 Ok(mut remote) => {
147 let mut fetch_options = git2::FetchOptions::new();
148 fetch_options.remote_callbacks(self.remote_callbacks()?);
149
150 remote
151 .fetch::<&str>(&[], Some(&mut fetch_options), None)
152 .with_context(|| {
153 format!(
154 "Failed to fetch from remote '{}' for repo at `{}`\n\
155 \n\
156 Possible causes:\n\
157 - SSH agent not running or not accessible (check SSH_AUTH_SOCK)\n\
158 - SSH keys not properly configured in ~/.ssh/\n\
159 - Credential helper not configured (git config credential.helper)\n\
160 - Network/firewall issues\n\
161 \n\
162 Try running: git fetch --verbose\n\
163 Or check authentication with: git-wok test-auth",
164 tracking.remote,
165 self.work_dir.display()
166 )
167 })?;
168 },
169 Err(_) => {
170 return Ok(());
172 },
173 }
174
175 Ok(())
176 }
177
178 pub fn ensure_on_branch(&self, branch_name: &str) -> Result<()> {
179 if !self.is_worktree_clean()? {
180 bail!(
181 "Refusing to switch branches with uncommitted changes in `{}`",
182 self.work_dir.display()
183 );
184 }
185
186 if !self.git_repo.head_detached().with_context(|| {
187 format!(
188 "Cannot determine head state for repo at `{}`",
189 self.work_dir.display()
190 )
191 })? && let Ok(head) = self.git_repo.head()
192 && head.shorthand() == Some(branch_name)
193 {
194 return Ok(());
195 }
196
197 let local_ref = format!("refs/heads/{}", branch_name);
198 if self.git_repo.find_reference(&local_ref).is_ok() {
199 self.switch_forced(branch_name)?;
200 return Ok(());
201 }
202
203 let remote_name = self.get_remote_name_for_branch(branch_name)?;
204 if let Ok(mut remote) = self.git_repo.find_remote(&remote_name) {
205 let mut fetch_options = git2::FetchOptions::new();
206 fetch_options.remote_callbacks(self.remote_callbacks()?);
207 remote.fetch::<&str>(&[], Some(&mut fetch_options), None)?;
208 }
209
210 let remote_ref = format!("refs/remotes/{}/{}", remote_name, branch_name);
211 if let Ok(remote_oid) = self.git_repo.refname_to_id(&remote_ref) {
212 let remote_commit = self.git_repo.find_commit(remote_oid)?;
213 self.git_repo.branch(branch_name, &remote_commit, false)?;
214 let mut local_branch = self
215 .git_repo
216 .find_branch(branch_name, git2::BranchType::Local)?;
217 local_branch
218 .set_upstream(Some(&format!("{}/{}", remote_name, branch_name)))?;
219 self.switch(branch_name)?;
220 return Ok(());
221 }
222
223 let head = self.git_repo.head()?;
224 let current_commit = head.peel_to_commit()?;
225 self.git_repo.branch(branch_name, ¤t_commit, false)?;
226 self.switch(branch_name)?;
227 Ok(())
228 }
229
230 pub fn ensure_on_branch_existing_or_remote(
231 &self,
232 branch_name: &str,
233 create: bool,
234 ) -> Result<()> {
235 if !self.is_worktree_clean()? {
236 bail!(
237 "Refusing to switch branches with uncommitted changes in `{}`",
238 self.work_dir.display()
239 );
240 }
241
242 if !self.git_repo.head_detached().with_context(|| {
243 format!(
244 "Cannot determine head state for repo at `{}`",
245 self.work_dir.display()
246 )
247 })? && let Ok(head) = self.git_repo.head()
248 && head.shorthand() == Some(branch_name)
249 {
250 return Ok(());
251 }
252
253 let local_ref = format!("refs/heads/{}", branch_name);
254 if self.git_repo.find_reference(&local_ref).is_ok() {
255 self.switch(branch_name)?;
256 return Ok(());
257 }
258
259 let remote_name = self.get_remote_name_for_branch(branch_name)?;
260 if let Ok(mut remote) = self.git_repo.find_remote(&remote_name) {
261 let mut fetch_options = git2::FetchOptions::new();
262 fetch_options.remote_callbacks(self.remote_callbacks()?);
263 remote.fetch::<&str>(&[], Some(&mut fetch_options), None)?;
264 }
265
266 let remote_ref = format!("refs/remotes/{}/{}", remote_name, branch_name);
267 if let Ok(remote_oid) = self.git_repo.refname_to_id(&remote_ref) {
268 let remote_commit = self.git_repo.find_commit(remote_oid)?;
269 self.git_repo.branch(branch_name, &remote_commit, false)?;
270 let mut local_branch = self
271 .git_repo
272 .find_branch(branch_name, git2::BranchType::Local)?;
273 local_branch
274 .set_upstream(Some(&format!("{}/{}", remote_name, branch_name)))?;
275 self.switch_forced(branch_name)?;
276 return Ok(());
277 }
278
279 if create {
280 let head = self.git_repo.head()?;
281 let current_commit = head.peel_to_commit()?;
282 self.git_repo.branch(branch_name, ¤t_commit, false)?;
283 self.switch_forced(branch_name)?;
284 return Ok(());
285 }
286
287 bail!(
288 "Branch '{}' does not exist and --create not specified",
289 branch_name
290 );
291 }
292
293 fn rebase(
294 &self,
295 _branch_name: &str,
296 remote_commit: &git2::Commit,
297 ) -> Result<MergeResult> {
298 let _local_commit = self.git_repo.head()?.peel_to_commit()?;
299 let remote_oid = remote_commit.id();
300
301 let remote_annotated = self.git_repo.find_annotated_commit(remote_oid)?;
303
304 let signature = self.git_repo.signature()?;
306 let mut rebase = self.git_repo.rebase(
307 None, Some(&remote_annotated), None, None, )?;
312
313 let mut has_conflicts = false;
315 while let Some(op) = rebase.next() {
316 match op {
317 Ok(_rebase_op) => {
318 let index = self.git_repo.index()?;
320 if index.has_conflicts() {
321 has_conflicts = true;
322 break;
323 }
324
325 if rebase.commit(None, &signature, None).is_err() {
327 has_conflicts = true;
328 break;
329 }
330 },
331 Err(_) => {
332 has_conflicts = true;
333 break;
334 },
335 }
336 }
337
338 if has_conflicts {
339 return Ok(MergeResult::Conflicts);
341 }
342
343 rebase.finish(Some(&signature))?;
345
346 Ok(MergeResult::Rebased)
347 }
348
349 pub fn merge(&self, branch_name: &str) -> Result<MergeResult> {
350 self.fetch()?;
352
353 let tracking = match self.tracking_branch(branch_name)? {
355 Some(tracking) => tracking,
356 None => {
357 return Ok(MergeResult::UpToDate);
359 },
360 };
361
362 let remote_branch_oid = match self.git_repo.refname_to_id(&tracking.remote_ref)
364 {
365 Ok(oid) => oid,
366 Err(_) => {
367 return Ok(MergeResult::UpToDate);
369 },
370 };
371
372 let remote_commit = self.git_repo.find_commit(remote_branch_oid)?;
373 let local_commit = self.git_repo.head()?.peel_to_commit()?;
374
375 if local_commit.id() == remote_commit.id() {
377 return Ok(MergeResult::UpToDate);
378 }
379
380 if self
382 .git_repo
383 .graph_descendant_of(remote_commit.id(), local_commit.id())?
384 {
385 self.git_repo.reference(
387 &format!("refs/heads/{}", branch_name),
388 remote_commit.id(),
389 true,
390 &format!("Fast-forward '{}' to {}", branch_name, tracking.remote_ref),
391 )?;
392 self.git_repo
393 .set_head(&format!("refs/heads/{}", branch_name))?;
394 let mut checkout = CheckoutBuilder::new();
395 checkout.force();
396 self.git_repo.checkout_head(Some(&mut checkout))?;
397 return Ok(MergeResult::FastForward);
398 }
399
400 let pull_strategy = self.get_pull_strategy(branch_name)?;
402
403 match pull_strategy {
404 PullStrategy::Rebase => {
405 self.rebase(branch_name, &remote_commit)
407 },
408 PullStrategy::Merge => {
409 self.do_merge(branch_name, &local_commit, &remote_commit, &tracking)
411 },
412 }
413 }
414
415 fn do_merge(
416 &self,
417 branch_name: &str,
418 local_commit: &git2::Commit,
419 remote_commit: &git2::Commit,
420 tracking: &TrackingBranch,
421 ) -> Result<MergeResult> {
422 let mut merge_opts = git2::MergeOptions::new();
424 merge_opts.fail_on_conflict(false); let _merge_result = self.git_repo.merge_commits(
427 local_commit,
428 remote_commit,
429 Some(&merge_opts),
430 )?;
431
432 let mut index = self.git_repo.index()?;
434 let has_conflicts = index.has_conflicts();
435
436 if !has_conflicts {
437 let signature = self.git_repo.signature()?;
439 let tree_id = index.write_tree()?;
440 let tree = self.git_repo.find_tree(tree_id)?;
441
442 self.git_repo.commit(
443 Some(&format!("refs/heads/{}", branch_name)),
444 &signature,
445 &signature,
446 &format!("Merge remote-tracking branch '{}'", tracking.remote_ref),
447 &tree,
448 &[local_commit, remote_commit],
449 )?;
450
451 self.git_repo.cleanup_state()?;
452
453 Ok(MergeResult::Merged)
454 } else {
455 Ok(MergeResult::Conflicts)
457 }
458 }
459
460 pub fn get_remote_name_for_branch(&self, branch_name: &str) -> Result<String> {
461 if let Some(tracking) = self.tracking_branch(branch_name)? {
462 Ok(tracking.remote)
463 } else {
464 Ok("origin".to_string())
466 }
467 }
468
469 pub fn get_remote_comparison(
471 &self,
472 branch_name: &str,
473 ) -> Result<Option<RemoteComparison>> {
474 let tracking = match self.tracking_branch(branch_name)? {
476 Some(tracking) => tracking,
477 None => return Ok(None), };
479
480 let remote_oid = match self.git_repo.refname_to_id(&tracking.remote_ref) {
482 Ok(oid) => oid,
483 Err(_) => {
484 return Ok(Some(RemoteComparison::NoRemote));
486 },
487 };
488
489 let local_oid = self.git_repo.head()?.peel_to_commit()?.id();
491
492 if local_oid == remote_oid {
494 return Ok(Some(RemoteComparison::UpToDate));
495 }
496
497 let (ahead, behind) =
499 self.git_repo.graph_ahead_behind(local_oid, remote_oid)?;
500
501 if ahead > 0 && behind > 0 {
502 Ok(Some(RemoteComparison::Diverged(ahead, behind)))
503 } else if ahead > 0 {
504 Ok(Some(RemoteComparison::Ahead(ahead)))
505 } else if behind > 0 {
506 Ok(Some(RemoteComparison::Behind(behind)))
507 } else {
508 Ok(Some(RemoteComparison::UpToDate))
509 }
510 }
511
512 pub fn remote_callbacks(&self) -> Result<git2::RemoteCallbacks<'static>> {
513 self.remote_callbacks_impl(false)
514 }
515
516 pub fn remote_callbacks_verbose(&self) -> Result<git2::RemoteCallbacks<'static>> {
517 self.remote_callbacks_impl(true)
518 }
519
520 fn remote_callbacks_impl(
521 &self,
522 verbose: bool,
523 ) -> Result<git2::RemoteCallbacks<'static>> {
524 let config = self.git_repo.config()?;
525
526 let mut callbacks = git2::RemoteCallbacks::new();
527 callbacks.credentials(move |url, username_from_url, allowed| {
528 if verbose {
529 eprintln!("DEBUG: Credential callback invoked");
530 eprintln!(" URL: {}", url);
531 eprintln!(" Username from URL: {:?}", username_from_url);
532 eprintln!(" Allowed types: {:?}", allowed);
533 }
534
535 if allowed.contains(git2::CredentialType::SSH_KEY) {
537 if let Some(username) = username_from_url {
538 if std::env::var("SSH_AUTH_SOCK").is_ok() {
540 if verbose {
541 eprintln!(
542 " Attempting: SSH key from agent for user '{}'",
543 username
544 );
545 }
546 match git2::Cred::ssh_key_from_agent(username) {
547 Ok(cred) => {
548 if verbose {
549 eprintln!(" SUCCESS: SSH key from agent");
550 }
551 return Ok(cred);
552 },
553 Err(e) => {
554 if verbose {
555 eprintln!(" FAILED: SSH key from agent - {}", e);
556 }
557 },
558 }
559 } else if verbose {
560 eprintln!(
561 " SKIPPED: SSH key from agent (SSH_AUTH_SOCK not set)"
562 );
563 }
564 } else if verbose {
565 eprintln!(" SKIPPED: SSH key from agent (no username provided)");
566 }
567
568 if let Some(username) = username_from_url
570 && let Ok(home) = std::env::var("HOME")
571 {
572 let key_paths = vec![
573 format!("{}/.ssh/id_ed25519", home),
574 format!("{}/.ssh/id_rsa", home),
575 format!("{}/.ssh/id_ecdsa", home),
576 ];
577
578 for key_path in key_paths {
579 if path::Path::new(&key_path).exists() {
580 if verbose {
581 eprintln!(" Attempting: SSH key file at {}", key_path);
582 }
583 match git2::Cred::ssh_key(
584 username,
585 None, path::Path::new(&key_path),
587 None, ) {
589 Ok(cred) => {
590 if verbose {
591 eprintln!(" SUCCESS: SSH key file");
592 }
593 return Ok(cred);
594 },
595 Err(e) => {
596 if verbose {
597 eprintln!(" FAILED: SSH key file - {}", e);
598 }
599 },
600 }
601 }
602 }
603 }
604 }
605
606 if allowed.contains(git2::CredentialType::USER_PASS_PLAINTEXT)
608 || allowed.contains(git2::CredentialType::SSH_KEY)
609 || allowed.contains(git2::CredentialType::DEFAULT)
610 {
611 if verbose {
612 eprintln!(" Attempting: Credential helper");
613 }
614 match git2::Cred::credential_helper(&config, url, username_from_url) {
615 Ok(cred) => {
616 if verbose {
617 eprintln!(" SUCCESS: Credential helper");
618 }
619 return Ok(cred);
620 },
621 Err(e) => {
622 if verbose {
623 eprintln!(" FAILED: Credential helper - {}", e);
624 }
625 },
626 }
627 }
628
629 if allowed.contains(git2::CredentialType::USERNAME) {
631 let username = username_from_url.unwrap_or("git");
632 if verbose {
633 eprintln!(" Attempting: Username only ('{}')", username);
634 }
635 match git2::Cred::username(username) {
636 Ok(cred) => {
637 if verbose {
638 eprintln!(" SUCCESS: Username");
639 }
640 return Ok(cred);
641 },
642 Err(e) => {
643 if verbose {
644 eprintln!(" FAILED: Username - {}", e);
645 }
646 },
647 }
648 }
649
650 if verbose {
652 eprintln!(" Attempting: Default credentials");
653 }
654 match git2::Cred::default() {
655 Ok(cred) => {
656 if verbose {
657 eprintln!(" SUCCESS: Default credentials");
658 }
659 Ok(cred)
660 },
661 Err(e) => {
662 if verbose {
663 eprintln!(" FAILED: All credential methods exhausted");
664 eprintln!(" Last error: {}", e);
665 }
666 Err(e)
667 },
668 }
669 });
670
671 Ok(callbacks)
672 }
673
674 fn resolve_reference(&self, short_name: &str) -> Result<String> {
675 Ok(self
676 .git_repo
677 .resolve_reference_from_short_name(short_name)?
678 .name()
679 .with_context(|| {
680 format!(
681 "Cannot resolve head reference for repo at `{}`",
682 self.work_dir.display()
683 )
684 })?
685 .to_owned())
686 }
687
688 pub fn tracking_branch(&self, branch_name: &str) -> Result<Option<TrackingBranch>> {
689 let config = self.git_repo.config()?;
690
691 let remote_key = format!("branch.{}.remote", branch_name);
692 let merge_key = format!("branch.{}.merge", branch_name);
693
694 let remote = match config.get_string(&remote_key) {
695 Ok(name) => name,
696 Err(err) if err.code() == git2::ErrorCode::NotFound => return Ok(None),
697 Err(err) => return Err(err.into()),
698 };
699
700 let merge_ref = match config.get_string(&merge_key) {
701 Ok(name) => name,
702 Err(err) if err.code() == git2::ErrorCode::NotFound => return Ok(None),
703 Err(err) => return Err(err.into()),
704 };
705
706 let branch_short = merge_ref
707 .strip_prefix("refs/heads/")
708 .unwrap_or(&merge_ref)
709 .to_owned();
710
711 let remote_ref = format!("refs/remotes/{}/{}", remote, branch_short);
712
713 Ok(Some(TrackingBranch { remote, remote_ref }))
714 }
715
716 fn get_pull_strategy(&self, branch_name: &str) -> Result<PullStrategy> {
717 let config = self.git_repo.config()?;
718
719 let branch_rebase_key = format!("branch.{}.rebase", branch_name);
721 if let Ok(value) = config.get_string(&branch_rebase_key) {
722 return Ok(parse_rebase_config(&value));
723 }
724
725 if let Ok(value) = config.get_string("pull.rebase") {
727 return Ok(parse_rebase_config(&value));
728 }
729
730 if let Ok(value) = config.get_bool("pull.rebase") {
732 return Ok(if value {
733 PullStrategy::Rebase
734 } else {
735 PullStrategy::Merge
736 });
737 }
738
739 Ok(PullStrategy::Merge)
741 }
742
743 fn is_worktree_clean(&self) -> Result<bool> {
744 let mut status_options = StatusOptions::new();
745 status_options.include_ignored(false);
746 status_options.include_untracked(true);
747 let statuses = self.git_repo.statuses(Some(&mut status_options))?;
748 Ok(statuses.is_empty())
749 }
750}
751
752impl fmt::Debug for Repo {
753 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
754 f.debug_struct("Repo")
755 .field("work_dir", &self.work_dir)
756 .field("head", &self.head)
757 .field("subrepos", &self.subrepos)
758 .finish()
759 }
760}
761
/// Upstream of a local branch, derived from the `branch.<name>.remote` and
/// `branch.<name>.merge` configuration keys.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TrackingBranch {
    /// Name of the configured remote.
    pub remote: String,
    /// Full remote-tracking ref name, e.g. `refs/remotes/origin/main`.
    pub remote_ref: String,
}
766
/// How upstream changes are integrated when a fast-forward is not possible.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PullStrategy {
    /// Create a merge commit.
    Merge,
    /// Replay local commits on top of upstream.
    Rebase,
}
772
773fn parse_rebase_config(value: &str) -> PullStrategy {
774 match value.to_lowercase().as_str() {
775 "true" | "interactive" | "i" | "merges" | "m" => PullStrategy::Rebase,
776 "false" => PullStrategy::Merge,
777 _ => PullStrategy::Merge, }
779}