1use std::{fmt, path};
2
3use anyhow::*;
4use git2::build::CheckoutBuilder;
5use std::result::Result::Ok;
6
/// Outcome of integrating a branch's upstream into the local branch (see `Repo::merge`).
#[derive(Clone, Debug, PartialEq)]
pub enum MergeResult {
    /// Nothing to integrate, e.g. no upstream is configured or both sides point at the
    /// same commit.
    UpToDate,
    /// The local branch was moved forward to the upstream commit.
    FastForward,
    /// A merge commit joining local and upstream history was created.
    Merged,
    /// Local commits were replayed on top of the upstream commit.
    Rebased,
    /// Integration stopped on conflicts that need manual resolution.
    Conflicts,
}
15
/// How a local branch relates to its upstream (see `Repo::get_remote_comparison`).
#[derive(Clone, Debug, PartialEq)]
pub enum RemoteComparison {
    /// Local and upstream point at the same commit.
    UpToDate,
    /// Local has this many commits the upstream does not.
    Ahead(usize),
    /// The upstream has this many commits local does not.
    Behind(usize),
    /// Both sides have unique commits, as `(ahead, behind)`.
    Diverged(usize, usize),
    /// An upstream is configured but its remote-tracking ref could not be resolved.
    NoRemote,
}
24
25pub struct Repo {
26 pub git_repo: git2::Repository,
27 pub work_dir: path::PathBuf,
28 pub head: String,
29 pub subrepos: Vec<Repo>,
30}
31
32impl Repo {
33 pub fn new(work_dir: &path::Path, head_name: Option<&str>) -> Result<Self> {
34 let git_repo = git2::Repository::open(work_dir)
35 .with_context(|| format!("Cannot open repo at `{}`", work_dir.display()))?;
36
37 let head = match head_name {
38 Some(name) => String::from(name),
39 None => {
40 if git_repo.head_detached().with_context(|| {
41 format!(
42 "Cannot determine head state for repo at `{}`",
43 work_dir.display()
44 )
45 })? {
46 bail!(
47 "Cannot operate on a detached head for repo at `{}`",
48 work_dir.display()
49 )
50 }
51
52 String::from(git_repo.head().with_context(|| {
53 format!(
54 "Cannot find the head branch for repo at `{}`. Is it detached?",
55 work_dir.display()
56 )
57 })?.shorthand().with_context(|| {
58 format!(
59 "Cannot find a human readable representation of the head ref for repo at `{}`",
60 work_dir.display(),
61 )
62 })?)
63 },
64 };
65
66 let subrepos = git_repo
67 .submodules()
68 .with_context(|| {
69 format!(
70 "Cannot load submodules for repo at `{}`",
71 work_dir.display()
72 )
73 })?
74 .iter()
75 .map(|submodule| Repo::new(&work_dir.join(submodule.path()), None))
76 .collect::<Result<Vec<Repo>>>()?;
77
78 Ok(Repo {
79 git_repo,
80 work_dir: path::PathBuf::from(work_dir),
81 head,
82 subrepos,
83 })
84 }
85
86 pub fn get_subrepo_by_path(&self, subrepo_path: &path::PathBuf) -> Option<&Repo> {
87 self.subrepos
88 .iter()
89 .find(|subrepo| subrepo.work_dir == self.work_dir.join(subrepo_path))
90 }
91
92 pub fn sync(&self) -> Result<()> {
93 self.switch(&self.head)?;
94 Ok(())
95 }
96
97 pub fn switch(&self, head: &str) -> Result<()> {
98 self.git_repo.set_head(&self.resolve_reference(head)?)?;
99 self.git_repo.checkout_head(None)?;
100 Ok(())
101 }
102
103 pub fn fetch(&self) -> Result<()> {
104 let head_ref = self.git_repo.head()?;
106 let branch_name = head_ref.shorthand().with_context(|| {
107 format!(
108 "Cannot get branch name for repo at `{}`",
109 self.work_dir.display()
110 )
111 })?;
112
113 let tracking = match self.tracking_branch(branch_name)? {
114 Some(tracking) => tracking,
115 None => {
116 return Ok(());
118 },
119 };
120
121 match self.git_repo.find_remote(&tracking.remote) {
123 Ok(mut remote) => {
124 let mut fetch_options = git2::FetchOptions::new();
125 fetch_options.remote_callbacks(self.remote_callbacks()?);
126
127 remote
128 .fetch::<&str>(&[], Some(&mut fetch_options), None)
129 .with_context(|| {
130 format!(
131 "Failed to fetch from remote '{}' for repo at `{}`\n\
132 \n\
133 Possible causes:\n\
134 - SSH agent not running or not accessible (check SSH_AUTH_SOCK)\n\
135 - SSH keys not properly configured in ~/.ssh/\n\
136 - Credential helper not configured (git config credential.helper)\n\
137 - Network/firewall issues\n\
138 \n\
139 Try running: git fetch --verbose\n\
140 Or check authentication with: git-wok test-auth",
141 tracking.remote,
142 self.work_dir.display()
143 )
144 })?;
145 },
146 Err(_) => {
147 return Ok(());
149 },
150 }
151
152 Ok(())
153 }
154
155 fn rebase(
156 &self,
157 _branch_name: &str,
158 remote_commit: &git2::Commit,
159 ) -> Result<MergeResult> {
160 let _local_commit = self.git_repo.head()?.peel_to_commit()?;
161 let remote_oid = remote_commit.id();
162
163 let remote_annotated = self.git_repo.find_annotated_commit(remote_oid)?;
165
166 let signature = self.git_repo.signature()?;
168 let mut rebase = self.git_repo.rebase(
169 None, Some(&remote_annotated), None, None, )?;
174
175 let mut has_conflicts = false;
177 while let Some(op) = rebase.next() {
178 match op {
179 Ok(_rebase_op) => {
180 let index = self.git_repo.index()?;
182 if index.has_conflicts() {
183 has_conflicts = true;
184 break;
185 }
186
187 if rebase.commit(None, &signature, None).is_err() {
189 has_conflicts = true;
190 break;
191 }
192 },
193 Err(_) => {
194 has_conflicts = true;
195 break;
196 },
197 }
198 }
199
200 if has_conflicts {
201 return Ok(MergeResult::Conflicts);
203 }
204
205 rebase.finish(Some(&signature))?;
207
208 Ok(MergeResult::Rebased)
209 }
210
211 pub fn merge(&self, branch_name: &str) -> Result<MergeResult> {
212 self.fetch()?;
214
215 let tracking = match self.tracking_branch(branch_name)? {
217 Some(tracking) => tracking,
218 None => {
219 return Ok(MergeResult::UpToDate);
221 },
222 };
223
224 let remote_branch_oid = match self.git_repo.refname_to_id(&tracking.remote_ref)
226 {
227 Ok(oid) => oid,
228 Err(_) => {
229 return Ok(MergeResult::UpToDate);
231 },
232 };
233
234 let remote_commit = self.git_repo.find_commit(remote_branch_oid)?;
235 let local_commit = self.git_repo.head()?.peel_to_commit()?;
236
237 if local_commit.id() == remote_commit.id() {
239 return Ok(MergeResult::UpToDate);
240 }
241
242 if self
244 .git_repo
245 .graph_descendant_of(remote_commit.id(), local_commit.id())?
246 {
247 self.git_repo.reference(
249 &format!("refs/heads/{}", branch_name),
250 remote_commit.id(),
251 true,
252 &format!("Fast-forward '{}' to {}", branch_name, tracking.remote_ref),
253 )?;
254 self.git_repo
255 .set_head(&format!("refs/heads/{}", branch_name))?;
256 let mut checkout = CheckoutBuilder::new();
257 checkout.force();
258 self.git_repo.checkout_head(Some(&mut checkout))?;
259 return Ok(MergeResult::FastForward);
260 }
261
262 let pull_strategy = self.get_pull_strategy(branch_name)?;
264
265 match pull_strategy {
266 PullStrategy::Rebase => {
267 self.rebase(branch_name, &remote_commit)
269 },
270 PullStrategy::Merge => {
271 self.do_merge(branch_name, &local_commit, &remote_commit, &tracking)
273 },
274 }
275 }
276
277 fn do_merge(
278 &self,
279 branch_name: &str,
280 local_commit: &git2::Commit,
281 remote_commit: &git2::Commit,
282 tracking: &TrackingBranch,
283 ) -> Result<MergeResult> {
284 let mut merge_opts = git2::MergeOptions::new();
286 merge_opts.fail_on_conflict(false); let _merge_result = self.git_repo.merge_commits(
289 local_commit,
290 remote_commit,
291 Some(&merge_opts),
292 )?;
293
294 let mut index = self.git_repo.index()?;
296 let has_conflicts = index.has_conflicts();
297
298 if !has_conflicts {
299 let signature = self.git_repo.signature()?;
301 let tree_id = index.write_tree()?;
302 let tree = self.git_repo.find_tree(tree_id)?;
303
304 self.git_repo.commit(
305 Some(&format!("refs/heads/{}", branch_name)),
306 &signature,
307 &signature,
308 &format!("Merge remote-tracking branch '{}'", tracking.remote_ref),
309 &tree,
310 &[local_commit, remote_commit],
311 )?;
312
313 self.git_repo.cleanup_state()?;
314
315 Ok(MergeResult::Merged)
316 } else {
317 Ok(MergeResult::Conflicts)
319 }
320 }
321
322 pub fn get_remote_name_for_branch(&self, branch_name: &str) -> Result<String> {
323 if let Some(tracking) = self.tracking_branch(branch_name)? {
324 Ok(tracking.remote)
325 } else {
326 Ok("origin".to_string())
328 }
329 }
330
331 pub fn get_remote_comparison(
333 &self,
334 branch_name: &str,
335 ) -> Result<Option<RemoteComparison>> {
336 let tracking = match self.tracking_branch(branch_name)? {
338 Some(tracking) => tracking,
339 None => return Ok(None), };
341
342 let remote_oid = match self.git_repo.refname_to_id(&tracking.remote_ref) {
344 Ok(oid) => oid,
345 Err(_) => {
346 return Ok(Some(RemoteComparison::NoRemote));
348 },
349 };
350
351 let local_oid = self.git_repo.head()?.peel_to_commit()?.id();
353
354 if local_oid == remote_oid {
356 return Ok(Some(RemoteComparison::UpToDate));
357 }
358
359 let (ahead, behind) =
361 self.git_repo.graph_ahead_behind(local_oid, remote_oid)?;
362
363 if ahead > 0 && behind > 0 {
364 Ok(Some(RemoteComparison::Diverged(ahead, behind)))
365 } else if ahead > 0 {
366 Ok(Some(RemoteComparison::Ahead(ahead)))
367 } else if behind > 0 {
368 Ok(Some(RemoteComparison::Behind(behind)))
369 } else {
370 Ok(Some(RemoteComparison::UpToDate))
371 }
372 }
373
374 pub fn remote_callbacks(&self) -> Result<git2::RemoteCallbacks<'static>> {
375 self.remote_callbacks_impl(false)
376 }
377
378 pub fn remote_callbacks_verbose(&self) -> Result<git2::RemoteCallbacks<'static>> {
379 self.remote_callbacks_impl(true)
380 }
381
382 fn remote_callbacks_impl(
383 &self,
384 verbose: bool,
385 ) -> Result<git2::RemoteCallbacks<'static>> {
386 let config = self.git_repo.config()?;
387
388 let mut callbacks = git2::RemoteCallbacks::new();
389 callbacks.credentials(move |url, username_from_url, allowed| {
390 if verbose {
391 eprintln!("DEBUG: Credential callback invoked");
392 eprintln!(" URL: {}", url);
393 eprintln!(" Username from URL: {:?}", username_from_url);
394 eprintln!(" Allowed types: {:?}", allowed);
395 }
396
397 if allowed.contains(git2::CredentialType::SSH_KEY) {
399 if let Some(username) = username_from_url {
400 if std::env::var("SSH_AUTH_SOCK").is_ok() {
402 if verbose {
403 eprintln!(
404 " Attempting: SSH key from agent for user '{}'",
405 username
406 );
407 }
408 match git2::Cred::ssh_key_from_agent(username) {
409 Ok(cred) => {
410 if verbose {
411 eprintln!(" SUCCESS: SSH key from agent");
412 }
413 return Ok(cred);
414 },
415 Err(e) => {
416 if verbose {
417 eprintln!(" FAILED: SSH key from agent - {}", e);
418 }
419 },
420 }
421 } else if verbose {
422 eprintln!(
423 " SKIPPED: SSH key from agent (SSH_AUTH_SOCK not set)"
424 );
425 }
426 } else if verbose {
427 eprintln!(" SKIPPED: SSH key from agent (no username provided)");
428 }
429
430 if let Some(username) = username_from_url
432 && let Ok(home) = std::env::var("HOME")
433 {
434 let key_paths = vec![
435 format!("{}/.ssh/id_ed25519", home),
436 format!("{}/.ssh/id_rsa", home),
437 format!("{}/.ssh/id_ecdsa", home),
438 ];
439
440 for key_path in key_paths {
441 if path::Path::new(&key_path).exists() {
442 if verbose {
443 eprintln!(" Attempting: SSH key file at {}", key_path);
444 }
445 match git2::Cred::ssh_key(
446 username,
447 None, path::Path::new(&key_path),
449 None, ) {
451 Ok(cred) => {
452 if verbose {
453 eprintln!(" SUCCESS: SSH key file");
454 }
455 return Ok(cred);
456 },
457 Err(e) => {
458 if verbose {
459 eprintln!(" FAILED: SSH key file - {}", e);
460 }
461 },
462 }
463 }
464 }
465 }
466 }
467
468 if allowed.contains(git2::CredentialType::USER_PASS_PLAINTEXT)
470 || allowed.contains(git2::CredentialType::SSH_KEY)
471 || allowed.contains(git2::CredentialType::DEFAULT)
472 {
473 if verbose {
474 eprintln!(" Attempting: Credential helper");
475 }
476 match git2::Cred::credential_helper(&config, url, username_from_url) {
477 Ok(cred) => {
478 if verbose {
479 eprintln!(" SUCCESS: Credential helper");
480 }
481 return Ok(cred);
482 },
483 Err(e) => {
484 if verbose {
485 eprintln!(" FAILED: Credential helper - {}", e);
486 }
487 },
488 }
489 }
490
491 if allowed.contains(git2::CredentialType::USERNAME) {
493 let username = username_from_url.unwrap_or("git");
494 if verbose {
495 eprintln!(" Attempting: Username only ('{}')", username);
496 }
497 match git2::Cred::username(username) {
498 Ok(cred) => {
499 if verbose {
500 eprintln!(" SUCCESS: Username");
501 }
502 return Ok(cred);
503 },
504 Err(e) => {
505 if verbose {
506 eprintln!(" FAILED: Username - {}", e);
507 }
508 },
509 }
510 }
511
512 if verbose {
514 eprintln!(" Attempting: Default credentials");
515 }
516 match git2::Cred::default() {
517 Ok(cred) => {
518 if verbose {
519 eprintln!(" SUCCESS: Default credentials");
520 }
521 Ok(cred)
522 },
523 Err(e) => {
524 if verbose {
525 eprintln!(" FAILED: All credential methods exhausted");
526 eprintln!(" Last error: {}", e);
527 }
528 Err(e)
529 },
530 }
531 });
532
533 Ok(callbacks)
534 }
535
536 fn resolve_reference(&self, short_name: &str) -> Result<String> {
537 Ok(self
538 .git_repo
539 .resolve_reference_from_short_name(short_name)?
540 .name()
541 .with_context(|| {
542 format!(
543 "Cannot resolve head reference for repo at `{}`",
544 self.work_dir.display()
545 )
546 })?
547 .to_owned())
548 }
549
550 pub fn tracking_branch(&self, branch_name: &str) -> Result<Option<TrackingBranch>> {
551 let config = self.git_repo.config()?;
552
553 let remote_key = format!("branch.{}.remote", branch_name);
554 let merge_key = format!("branch.{}.merge", branch_name);
555
556 let remote = match config.get_string(&remote_key) {
557 Ok(name) => name,
558 Err(err) if err.code() == git2::ErrorCode::NotFound => return Ok(None),
559 Err(err) => return Err(err.into()),
560 };
561
562 let merge_ref = match config.get_string(&merge_key) {
563 Ok(name) => name,
564 Err(err) if err.code() == git2::ErrorCode::NotFound => return Ok(None),
565 Err(err) => return Err(err.into()),
566 };
567
568 let branch_short = merge_ref
569 .strip_prefix("refs/heads/")
570 .unwrap_or(&merge_ref)
571 .to_owned();
572
573 let remote_ref = format!("refs/remotes/{}/{}", remote, branch_short);
574
575 Ok(Some(TrackingBranch { remote, remote_ref }))
576 }
577
578 fn get_pull_strategy(&self, branch_name: &str) -> Result<PullStrategy> {
579 let config = self.git_repo.config()?;
580
581 let branch_rebase_key = format!("branch.{}.rebase", branch_name);
583 if let Ok(value) = config.get_string(&branch_rebase_key) {
584 return Ok(parse_rebase_config(&value));
585 }
586
587 if let Ok(value) = config.get_string("pull.rebase") {
589 return Ok(parse_rebase_config(&value));
590 }
591
592 if let Ok(value) = config.get_bool("pull.rebase") {
594 return Ok(if value {
595 PullStrategy::Rebase
596 } else {
597 PullStrategy::Merge
598 });
599 }
600
601 Ok(PullStrategy::Merge)
603 }
604}
605
606impl fmt::Debug for Repo {
607 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
608 f.debug_struct("Repo")
609 .field("work_dir", &self.work_dir)
610 .field("head", &self.head)
611 .field("subrepos", &self.subrepos)
612 .finish()
613 }
614}
615
/// Upstream of a local branch, as read from `branch.<name>.remote` and
/// `branch.<name>.merge`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TrackingBranch {
    /// Remote name from `branch.<name>.remote` (e.g. `origin`).
    pub remote: String,
    /// Full name of the ref holding the upstream commit, e.g. `refs/remotes/origin/main`.
    pub remote_ref: String,
}
620
/// How a diverged upstream is integrated, derived from `branch.<name>.rebase` /
/// `pull.rebase`.
#[derive(Clone, Debug, PartialEq)]
enum PullStrategy {
    /// Create a merge commit.
    Merge,
    /// Replay local commits on top of the upstream.
    Rebase,
}
626
627fn parse_rebase_config(value: &str) -> PullStrategy {
628 match value.to_lowercase().as_str() {
629 "true" | "interactive" | "i" | "merges" | "m" => PullStrategy::Rebase,
630 "false" => PullStrategy::Merge,
631 _ => PullStrategy::Merge, }
633}