1use std::{fmt, path};
2
3use anyhow::*;
4use git2::build::CheckoutBuilder;
5use std::result::Result::Ok;
6
/// Outcome of [`Repo::merge`] (pulling the upstream of a branch).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MergeResult {
    /// Nothing to integrate: no upstream configured, no remote-tracking ref,
    /// or the local branch already contains the remote tip.
    UpToDate,
    /// The local branch was moved forward to the remote tip.
    FastForward,
    /// A merge commit joining the local and remote tips was created.
    Merged,
    /// Local commits were replayed on top of the remote tip.
    Rebased,
    /// Integration stopped because of conflicts that need manual resolution.
    Conflicts,
}
15
/// Relationship between a local branch and its remote-tracking branch, as
/// reported by [`Repo::get_remote_comparison`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RemoteComparison {
    /// Local and remote point at the same commit.
    UpToDate,
    /// Local has this many commits the remote does not.
    Ahead(usize),
    /// Remote has this many commits the local branch does not.
    Behind(usize),
    /// Both sides have unique commits: `(ahead, behind)`.
    Diverged(usize, usize),
    /// An upstream is configured but its remote-tracking ref does not exist.
    NoRemote,
}
24
/// An opened git repository together with its submodules, each opened
/// recursively as its own `Repo` by `Repo::new`.
pub struct Repo {
    /// Underlying libgit2 repository handle.
    pub git_repo: git2::Repository,
    /// Directory the repository was opened from.
    pub work_dir: path::PathBuf,
    /// Short name of the head (branch) this repo operates on; `Repo::new`
    /// passes the same value down to every submodule.
    pub head: String,
    /// The repository's submodules.
    pub subrepos: Vec<Repo>,
}
31
32impl Repo {
33 pub fn new(work_dir: &path::Path, head_name: Option<&str>) -> Result<Self> {
34 let git_repo = git2::Repository::open(work_dir)
35 .with_context(|| format!("Cannot open repo at `{}`", work_dir.display()))?;
36
37 let head = match head_name {
38 Some(name) => String::from(name),
39 None => {
40 if git_repo.head_detached().with_context(|| {
41 format!(
42 "Cannot determine head state for repo at `{}`",
43 work_dir.display()
44 )
45 })? {
46 bail!(
47 "Cannot operate on a detached head for repo at `{}`",
48 work_dir.display()
49 )
50 }
51
52 String::from(git_repo.head().with_context(|| {
53 format!(
54 "Cannot find the head branch for repo at `{}`. Is it detached?",
55 work_dir.display()
56 )
57 })?.shorthand().with_context(|| {
58 format!(
59 "Cannot find a human readable representation of the head ref for repo at `{}`",
60 work_dir.display(),
61 )
62 })?)
63 },
64 };
65
66 let subrepos = git_repo
67 .submodules()
68 .with_context(|| {
69 format!(
70 "Cannot load submodules for repo at `{}`",
71 work_dir.display()
72 )
73 })?
74 .iter()
75 .map(|submodule| Repo::new(&work_dir.join(submodule.path()), Some(&head)))
76 .collect::<Result<Vec<Repo>>>()?;
77
78 Ok(Repo {
79 git_repo,
80 work_dir: path::PathBuf::from(work_dir),
81 head,
82 subrepos,
83 })
84 }
85
86 pub fn get_subrepo_by_path(&self, subrepo_path: &path::PathBuf) -> Option<&Repo> {
87 self.subrepos
88 .iter()
89 .find(|subrepo| subrepo.work_dir == self.work_dir.join(subrepo_path))
90 }
91
92 pub fn sync(&self) -> Result<()> {
93 self.switch(&self.head)?;
94 Ok(())
95 }
96
97 pub fn switch(&self, head: &str) -> Result<()> {
98 self.git_repo.set_head(&self.resolve_reference(head)?)?;
99 self.git_repo.checkout_head(None)?;
100 Ok(())
101 }
102
103 pub fn fetch(&self) -> Result<()> {
104 let head_ref = self.git_repo.head()?;
106 let branch_name = head_ref.shorthand().with_context(|| {
107 format!(
108 "Cannot get branch name for repo at `{}`",
109 self.work_dir.display()
110 )
111 })?;
112
113 let tracking = match self.tracking_branch(branch_name)? {
114 Some(tracking) => tracking,
115 None => {
116 return Ok(());
118 },
119 };
120
121 match self.git_repo.find_remote(&tracking.remote) {
123 Ok(mut remote) => {
124 let mut fetch_options = git2::FetchOptions::new();
125 fetch_options.remote_callbacks(self.remote_callbacks()?);
126
127 remote
128 .fetch::<&str>(&[], Some(&mut fetch_options), None)
129 .with_context(|| {
130 format!(
131 "Failed to fetch from remote '{}' for repo at `{}`\n\
132 \n\
133 Possible causes:\n\
134 - SSH agent not running or not accessible (check SSH_AUTH_SOCK)\n\
135 - SSH keys not properly configured in ~/.ssh/\n\
136 - Credential helper not configured (git config credential.helper)\n\
137 - Network/firewall issues\n\
138 \n\
139 Try running: git fetch --verbose\n\
140 Or check authentication with: git-wok test-auth",
141 tracking.remote,
142 self.work_dir.display()
143 )
144 })?;
145 },
146 Err(_) => {
147 return Ok(());
149 },
150 }
151
152 Ok(())
153 }
154
155 fn rebase(
156 &self,
157 _branch_name: &str,
158 remote_commit: &git2::Commit,
159 ) -> Result<MergeResult> {
160 let _local_commit = self.git_repo.head()?.peel_to_commit()?;
161 let remote_oid = remote_commit.id();
162
163 let remote_annotated = self.git_repo.find_annotated_commit(remote_oid)?;
165
166 let signature = self.git_repo.signature()?;
168 let mut rebase = self.git_repo.rebase(
169 None, Some(&remote_annotated), None, None, )?;
174
175 let mut has_conflicts = false;
177 while let Some(op) = rebase.next() {
178 match op {
179 Ok(_rebase_op) => {
180 let index = self.git_repo.index()?;
182 if index.has_conflicts() {
183 has_conflicts = true;
184 break;
185 }
186
187 if rebase.commit(None, &signature, None).is_err() {
189 has_conflicts = true;
190 break;
191 }
192 },
193 Err(_) => {
194 has_conflicts = true;
195 break;
196 },
197 }
198 }
199
200 if has_conflicts {
201 return Ok(MergeResult::Conflicts);
203 }
204
205 rebase.finish(Some(&signature))?;
207
208 Ok(MergeResult::Rebased)
209 }
210
211 pub fn merge(&self, branch_name: &str) -> Result<MergeResult> {
212 self.fetch()?;
214
215 let tracking = match self.tracking_branch(branch_name)? {
217 Some(tracking) => tracking,
218 None => {
219 return Ok(MergeResult::UpToDate);
221 },
222 };
223
224 let remote_branch_oid = match self.git_repo.refname_to_id(&tracking.remote_ref)
226 {
227 Ok(oid) => oid,
228 Err(_) => {
229 return Ok(MergeResult::UpToDate);
231 },
232 };
233
234 let remote_commit = self.git_repo.find_commit(remote_branch_oid)?;
235 let local_commit = self.git_repo.head()?.peel_to_commit()?;
236
237 if local_commit.id() == remote_commit.id() {
239 return Ok(MergeResult::UpToDate);
240 }
241
242 if self
244 .git_repo
245 .graph_descendant_of(remote_commit.id(), local_commit.id())?
246 {
247 self.git_repo.reference(
249 &format!("refs/heads/{}", branch_name),
250 remote_commit.id(),
251 true,
252 &format!("Fast-forward '{}' to {}", branch_name, tracking.remote_ref),
253 )?;
254 self.git_repo
255 .set_head(&format!("refs/heads/{}", branch_name))?;
256 let mut checkout = CheckoutBuilder::new();
257 checkout.force();
258 self.git_repo.checkout_head(Some(&mut checkout))?;
259 return Ok(MergeResult::FastForward);
260 }
261
262 let pull_strategy = self.get_pull_strategy(branch_name)?;
264
265 match pull_strategy {
266 PullStrategy::Rebase => {
267 self.rebase(branch_name, &remote_commit)
269 },
270 PullStrategy::Merge => {
271 self.do_merge(branch_name, &local_commit, &remote_commit, &tracking)
273 },
274 }
275 }
276
277 fn do_merge(
278 &self,
279 branch_name: &str,
280 local_commit: &git2::Commit,
281 remote_commit: &git2::Commit,
282 tracking: &TrackingBranch,
283 ) -> Result<MergeResult> {
284 let mut merge_opts = git2::MergeOptions::new();
286 merge_opts.fail_on_conflict(false); let _merge_result = self.git_repo.merge_commits(
289 local_commit,
290 remote_commit,
291 Some(&merge_opts),
292 )?;
293
294 let mut index = self.git_repo.index()?;
296 let has_conflicts = index.has_conflicts();
297
298 if !has_conflicts {
299 let signature = self.git_repo.signature()?;
301 let tree_id = index.write_tree()?;
302 let tree = self.git_repo.find_tree(tree_id)?;
303
304 self.git_repo.commit(
305 Some(&format!("refs/heads/{}", branch_name)),
306 &signature,
307 &signature,
308 &format!("Merge remote-tracking branch '{}'", tracking.remote_ref),
309 &tree,
310 &[local_commit, remote_commit],
311 )?;
312
313 self.git_repo.cleanup_state()?;
314
315 Ok(MergeResult::Merged)
316 } else {
317 Ok(MergeResult::Conflicts)
319 }
320 }
321
322 pub fn get_remote_name_for_branch(&self, branch_name: &str) -> Result<String> {
323 if let Some(tracking) = self.tracking_branch(branch_name)? {
324 Ok(tracking.remote)
325 } else {
326 Ok("origin".to_string())
328 }
329 }
330
331 pub fn get_remote_comparison(
333 &self,
334 branch_name: &str,
335 ) -> Result<Option<RemoteComparison>> {
336 let tracking = match self.tracking_branch(branch_name)? {
338 Some(tracking) => tracking,
339 None => return Ok(None), };
341
342 let remote_oid = match self.git_repo.refname_to_id(&tracking.remote_ref) {
344 Ok(oid) => oid,
345 Err(_) => {
346 return Ok(Some(RemoteComparison::NoRemote));
348 },
349 };
350
351 let local_oid = self.git_repo.head()?.peel_to_commit()?.id();
353
354 if local_oid == remote_oid {
356 return Ok(Some(RemoteComparison::UpToDate));
357 }
358
359 let (ahead, behind) =
361 self.git_repo.graph_ahead_behind(local_oid, remote_oid)?;
362
363 if ahead > 0 && behind > 0 {
364 Ok(Some(RemoteComparison::Diverged(ahead, behind)))
365 } else if ahead > 0 {
366 Ok(Some(RemoteComparison::Ahead(ahead)))
367 } else if behind > 0 {
368 Ok(Some(RemoteComparison::Behind(behind)))
369 } else {
370 Ok(Some(RemoteComparison::UpToDate))
371 }
372 }
373
374 pub fn remote_callbacks(&self) -> Result<git2::RemoteCallbacks<'static>> {
375 self.remote_callbacks_impl(false)
376 }
377
378 pub fn remote_callbacks_verbose(&self) -> Result<git2::RemoteCallbacks<'static>> {
379 self.remote_callbacks_impl(true)
380 }
381
382 fn remote_callbacks_impl(&self, verbose: bool) -> Result<git2::RemoteCallbacks<'static>> {
383 let config = self.git_repo.config()?;
384
385 let mut callbacks = git2::RemoteCallbacks::new();
386 callbacks.credentials(move |url, username_from_url, allowed| {
387 if verbose {
388 eprintln!("DEBUG: Credential callback invoked");
389 eprintln!(" URL: {}", url);
390 eprintln!(" Username from URL: {:?}", username_from_url);
391 eprintln!(" Allowed types: {:?}", allowed);
392 }
393
394 if allowed.contains(git2::CredentialType::SSH_KEY) {
396 if let Some(username) = username_from_url {
397 if std::env::var("SSH_AUTH_SOCK").is_ok() {
399 if verbose {
400 eprintln!(" Attempting: SSH key from agent for user '{}'", username);
401 }
402 match git2::Cred::ssh_key_from_agent(username) {
403 Ok(cred) => {
404 if verbose {
405 eprintln!(" SUCCESS: SSH key from agent");
406 }
407 return Ok(cred);
408 }
409 Err(e) => {
410 if verbose {
411 eprintln!(" FAILED: SSH key from agent - {}", e);
412 }
413 }
414 }
415 } else if verbose {
416 eprintln!(" SKIPPED: SSH key from agent (SSH_AUTH_SOCK not set)");
417 }
418 } else if verbose {
419 eprintln!(" SKIPPED: SSH key from agent (no username provided)");
420 }
421
422 if let Some(username) = username_from_url
424 && let Ok(home) = std::env::var("HOME")
425 {
426 let key_paths = vec![
427 format!("{}/.ssh/id_ed25519", home),
428 format!("{}/.ssh/id_rsa", home),
429 format!("{}/.ssh/id_ecdsa", home),
430 ];
431
432 for key_path in key_paths {
433 if path::Path::new(&key_path).exists() {
434 if verbose {
435 eprintln!(" Attempting: SSH key file at {}", key_path);
436 }
437 match git2::Cred::ssh_key(
438 username,
439 None, path::Path::new(&key_path),
441 None, ) {
443 Ok(cred) => {
444 if verbose {
445 eprintln!(" SUCCESS: SSH key file");
446 }
447 return Ok(cred);
448 }
449 Err(e) => {
450 if verbose {
451 eprintln!(" FAILED: SSH key file - {}", e);
452 }
453 }
454 }
455 }
456 }
457 }
458 }
459
460 if allowed.contains(git2::CredentialType::USER_PASS_PLAINTEXT)
462 || allowed.contains(git2::CredentialType::SSH_KEY)
463 || allowed.contains(git2::CredentialType::DEFAULT)
464 {
465 if verbose {
466 eprintln!(" Attempting: Credential helper");
467 }
468 match git2::Cred::credential_helper(&config, url, username_from_url) {
469 Ok(cred) => {
470 if verbose {
471 eprintln!(" SUCCESS: Credential helper");
472 }
473 return Ok(cred);
474 }
475 Err(e) => {
476 if verbose {
477 eprintln!(" FAILED: Credential helper - {}", e);
478 }
479 }
480 }
481 }
482
483 if allowed.contains(git2::CredentialType::USERNAME) {
485 let username = username_from_url.unwrap_or("git");
486 if verbose {
487 eprintln!(" Attempting: Username only ('{}')", username);
488 }
489 match git2::Cred::username(username) {
490 Ok(cred) => {
491 if verbose {
492 eprintln!(" SUCCESS: Username");
493 }
494 return Ok(cred);
495 }
496 Err(e) => {
497 if verbose {
498 eprintln!(" FAILED: Username - {}", e);
499 }
500 }
501 }
502 }
503
504 if verbose {
506 eprintln!(" Attempting: Default credentials");
507 }
508 match git2::Cred::default() {
509 Ok(cred) => {
510 if verbose {
511 eprintln!(" SUCCESS: Default credentials");
512 }
513 Ok(cred)
514 }
515 Err(e) => {
516 if verbose {
517 eprintln!(" FAILED: All credential methods exhausted");
518 eprintln!(" Last error: {}", e);
519 }
520 Err(e)
521 }
522 }
523 });
524
525 Ok(callbacks)
526 }
527
528 fn resolve_reference(&self, short_name: &str) -> Result<String> {
529 Ok(self
530 .git_repo
531 .resolve_reference_from_short_name(short_name)?
532 .name()
533 .with_context(|| {
534 format!(
535 "Cannot resolve head reference for repo at `{}`",
536 self.work_dir.display()
537 )
538 })?
539 .to_owned())
540 }
541
542 pub fn tracking_branch(&self, branch_name: &str) -> Result<Option<TrackingBranch>> {
543 let config = self.git_repo.config()?;
544
545 let remote_key = format!("branch.{}.remote", branch_name);
546 let merge_key = format!("branch.{}.merge", branch_name);
547
548 let remote = match config.get_string(&remote_key) {
549 Ok(name) => name,
550 Err(err) if err.code() == git2::ErrorCode::NotFound => return Ok(None),
551 Err(err) => return Err(err.into()),
552 };
553
554 let merge_ref = match config.get_string(&merge_key) {
555 Ok(name) => name,
556 Err(err) if err.code() == git2::ErrorCode::NotFound => return Ok(None),
557 Err(err) => return Err(err.into()),
558 };
559
560 let branch_short = merge_ref
561 .strip_prefix("refs/heads/")
562 .unwrap_or(&merge_ref)
563 .to_owned();
564
565 let remote_ref = format!("refs/remotes/{}/{}", remote, branch_short);
566
567 Ok(Some(TrackingBranch { remote, remote_ref }))
568 }
569
570 fn get_pull_strategy(&self, branch_name: &str) -> Result<PullStrategy> {
571 let config = self.git_repo.config()?;
572
573 let branch_rebase_key = format!("branch.{}.rebase", branch_name);
575 if let Ok(value) = config.get_string(&branch_rebase_key) {
576 return Ok(parse_rebase_config(&value));
577 }
578
579 if let Ok(value) = config.get_string("pull.rebase") {
581 return Ok(parse_rebase_config(&value));
582 }
583
584 if let Ok(value) = config.get_bool("pull.rebase") {
586 return Ok(if value {
587 PullStrategy::Rebase
588 } else {
589 PullStrategy::Merge
590 });
591 }
592
593 Ok(PullStrategy::Merge)
595 }
596}
597
598impl fmt::Debug for Repo {
599 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
600 f.debug_struct("Repo")
601 .field("work_dir", &self.work_dir)
602 .field("head", &self.head)
603 .field("subrepos", &self.subrepos)
604 .finish()
605 }
606}
607
/// Upstream configuration of a local branch.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TrackingBranch {
    /// Name of the upstream remote (`branch.<name>.remote`).
    pub remote: String,
    /// Fully-qualified upstream ref to compare or merge against,
    /// e.g. `refs/remotes/origin/main`.
    pub remote_ref: String,
}
612
/// How diverged local and remote histories are integrated on pull.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PullStrategy {
    /// Create a merge commit joining both tips.
    Merge,
    /// Replay local commits on top of the remote tip.
    Rebase,
}
618
619fn parse_rebase_config(value: &str) -> PullStrategy {
620 match value.to_lowercase().as_str() {
621 "true" | "interactive" | "i" | "merges" | "m" => PullStrategy::Rebase,
622 "false" => PullStrategy::Merge,
623 _ => PullStrategy::Merge, }
625}