git_wok/
repo.rs

1use std::{fmt, path};
2
3use anyhow::*;
4use git2::build::CheckoutBuilder;
5use std::result::Result::Ok;
6
/// Outcome of integrating a branch's upstream with [`Repo::merge`].
#[derive(Clone, Debug, PartialEq)]
pub enum MergeResult {
    /// The local branch was left unchanged (already current, or nothing to integrate).
    UpToDate,
    /// The local branch was advanced to the upstream commit.
    FastForward,
    /// A merge commit joining local and upstream history was created.
    Merged,
    /// Local commits were replayed on top of the upstream commit.
    Rebased,
    /// Integration stopped because of conflicts.
    Conflicts,
}
15
/// Relationship between the local branch tip and its remote-tracking branch.
#[derive(Clone, Debug, PartialEq)]
pub enum RemoteComparison {
    /// Both tips point at the same commit.
    UpToDate,
    /// The local side has this many commits the remote side lacks.
    Ahead(usize),
    /// The remote side has this many commits the local side lacks.
    Behind(usize),
    /// Both sides have unique commits, as `(ahead, behind)`.
    Diverged(usize, usize),
    /// An upstream is configured but its remote-tracking ref does not exist.
    NoRemote,
}
24
25pub struct Repo {
26    pub git_repo: git2::Repository,
27    pub work_dir: path::PathBuf,
28    pub head: String,
29    pub subrepos: Vec<Repo>,
30}
31
32impl Repo {
33    pub fn new(work_dir: &path::Path, head_name: Option<&str>) -> Result<Self> {
34        let git_repo = git2::Repository::open(work_dir)
35            .with_context(|| format!("Cannot open repo at `{}`", work_dir.display()))?;
36
37        let head = match head_name {
38            Some(name) => String::from(name),
39            None => {
40                if git_repo.head_detached().with_context(|| {
41                    format!(
42                        "Cannot determine head state for repo at `{}`",
43                        work_dir.display()
44                    )
45                })? {
46                    bail!(
47                        "Cannot operate on a detached head for repo at `{}`",
48                        work_dir.display()
49                    )
50                }
51
52                String::from(git_repo.head().with_context(|| {
53                    format!(
54                        "Cannot find the head branch for repo at `{}`. Is it detached?",
55                        work_dir.display()
56                    )
57                })?.shorthand().with_context(|| {
58                    format!(
59                        "Cannot find a human readable representation of the head ref for repo at `{}`",
60                        work_dir.display(),
61                    )
62                })?)
63            },
64        };
65
66        let subrepos = git_repo
67            .submodules()
68            .with_context(|| {
69                format!(
70                    "Cannot load submodules for repo at `{}`",
71                    work_dir.display()
72                )
73            })?
74            .iter()
75            .map(|submodule| Repo::new(&work_dir.join(submodule.path()), Some(&head)))
76            .collect::<Result<Vec<Repo>>>()?;
77
78        Ok(Repo {
79            git_repo,
80            work_dir: path::PathBuf::from(work_dir),
81            head,
82            subrepos,
83        })
84    }
85
86    pub fn get_subrepo_by_path(&self, subrepo_path: &path::PathBuf) -> Option<&Repo> {
87        self.subrepos
88            .iter()
89            .find(|subrepo| subrepo.work_dir == self.work_dir.join(subrepo_path))
90    }
91
92    pub fn sync(&self) -> Result<()> {
93        self.switch(&self.head)?;
94        Ok(())
95    }
96
97    pub fn switch(&self, head: &str) -> Result<()> {
98        self.git_repo.set_head(&self.resolve_reference(head)?)?;
99        self.git_repo.checkout_head(None)?;
100        Ok(())
101    }
102
103    pub fn fetch(&self) -> Result<()> {
104        // Get the remote for the current branch
105        let head_ref = self.git_repo.head()?;
106        let branch_name = head_ref.shorthand().with_context(|| {
107            format!(
108                "Cannot get branch name for repo at `{}`",
109                self.work_dir.display()
110            )
111        })?;
112
113        let tracking = match self.tracking_branch(branch_name)? {
114            Some(tracking) => tracking,
115            None => {
116                // No upstream configured, skip fetch
117                return Ok(());
118            },
119        };
120
121        // Check if remote exists
122        match self.git_repo.find_remote(&tracking.remote) {
123            Ok(mut remote) => {
124                let mut fetch_options = git2::FetchOptions::new();
125                fetch_options.remote_callbacks(self.remote_callbacks()?);
126
127                remote
128                    .fetch::<&str>(&[], Some(&mut fetch_options), None)
129                    .with_context(|| {
130                        format!(
131                            "Failed to fetch from remote '{}' for repo at `{}`\n\
132                            \n\
133                            Possible causes:\n\
134                            - SSH agent not running or not accessible (check SSH_AUTH_SOCK)\n\
135                            - SSH keys not properly configured in ~/.ssh/\n\
136                            - Credential helper not configured (git config credential.helper)\n\
137                            - Network/firewall issues\n\
138                            \n\
139                            Try running: git fetch --verbose\n\
140                            Or check authentication with: git-wok test-auth",
141                            tracking.remote,
142                            self.work_dir.display()
143                        )
144                    })?;
145            },
146            Err(_) => {
147                // No remote configured, skip fetch
148                return Ok(());
149            },
150        }
151
152        Ok(())
153    }
154
155    fn rebase(
156        &self,
157        _branch_name: &str,
158        remote_commit: &git2::Commit,
159    ) -> Result<MergeResult> {
160        let _local_commit = self.git_repo.head()?.peel_to_commit()?;
161        let remote_oid = remote_commit.id();
162
163        // Prepare annotated commit for rebase
164        let remote_annotated = self.git_repo.find_annotated_commit(remote_oid)?;
165
166        // Initialize rebase operation
167        let signature = self.git_repo.signature()?;
168        let mut rebase = self.git_repo.rebase(
169            None,                    // branch to rebase (None = HEAD)
170            Some(&remote_annotated), // upstream
171            None,                    // onto (None = upstream)
172            None,                    // options
173        )?;
174
175        // Process each commit in the rebase
176        let mut has_conflicts = false;
177        while let Some(op) = rebase.next() {
178            match op {
179                Ok(_rebase_op) => {
180                    // Check for conflicts
181                    let index = self.git_repo.index()?;
182                    if index.has_conflicts() {
183                        has_conflicts = true;
184                        break;
185                    }
186
187                    // Commit the rebased changes
188                    if rebase.commit(None, &signature, None).is_err() {
189                        has_conflicts = true;
190                        break;
191                    }
192                },
193                Err(_) => {
194                    has_conflicts = true;
195                    break;
196                },
197            }
198        }
199
200        if has_conflicts {
201            // Leave repository in state with conflicts for user to resolve
202            return Ok(MergeResult::Conflicts);
203        }
204
205        // Finish the rebase
206        rebase.finish(Some(&signature))?;
207
208        Ok(MergeResult::Rebased)
209    }
210
211    pub fn merge(&self, branch_name: &str) -> Result<MergeResult> {
212        // First, fetch the latest changes
213        self.fetch()?;
214
215        // Resolve the tracking branch reference
216        let tracking = match self.tracking_branch(branch_name)? {
217            Some(tracking) => tracking,
218            None => {
219                // No upstream configured, treat as up to date
220                return Ok(MergeResult::UpToDate);
221            },
222        };
223
224        // Check if remote branch exists
225        let remote_branch_oid = match self.git_repo.refname_to_id(&tracking.remote_ref)
226        {
227            Ok(oid) => oid,
228            Err(_) => {
229                // No remote branch, just return up to date
230                return Ok(MergeResult::UpToDate);
231            },
232        };
233
234        let remote_commit = self.git_repo.find_commit(remote_branch_oid)?;
235        let local_commit = self.git_repo.head()?.peel_to_commit()?;
236
237        // Check if we're already up to date
238        if local_commit.id() == remote_commit.id() {
239            return Ok(MergeResult::UpToDate);
240        }
241
242        // Check if we can fast-forward (works for both merge and rebase)
243        if self
244            .git_repo
245            .graph_descendant_of(remote_commit.id(), local_commit.id())?
246        {
247            // Fast-forward merge
248            self.git_repo.reference(
249                &format!("refs/heads/{}", branch_name),
250                remote_commit.id(),
251                true,
252                &format!("Fast-forward '{}' to {}", branch_name, tracking.remote_ref),
253            )?;
254            self.git_repo
255                .set_head(&format!("refs/heads/{}", branch_name))?;
256            let mut checkout = CheckoutBuilder::new();
257            checkout.force();
258            self.git_repo.checkout_head(Some(&mut checkout))?;
259            return Ok(MergeResult::FastForward);
260        }
261
262        // Determine pull strategy from git config
263        let pull_strategy = self.get_pull_strategy(branch_name)?;
264
265        match pull_strategy {
266            PullStrategy::Rebase => {
267                // Perform rebase
268                self.rebase(branch_name, &remote_commit)
269            },
270            PullStrategy::Merge => {
271                // Perform merge (existing logic)
272                self.do_merge(branch_name, &local_commit, &remote_commit, &tracking)
273            },
274        }
275    }
276
277    fn do_merge(
278        &self,
279        branch_name: &str,
280        local_commit: &git2::Commit,
281        remote_commit: &git2::Commit,
282        tracking: &TrackingBranch,
283    ) -> Result<MergeResult> {
284        // Perform a merge
285        let mut merge_opts = git2::MergeOptions::new();
286        merge_opts.fail_on_conflict(false); // Don't fail on conflicts, we'll handle them
287
288        let _merge_result = self.git_repo.merge_commits(
289            local_commit,
290            remote_commit,
291            Some(&merge_opts),
292        )?;
293
294        // Check if there are conflicts by examining the index
295        let mut index = self.git_repo.index()?;
296        let has_conflicts = index.has_conflicts();
297
298        if !has_conflicts {
299            // No conflicts, merge was successful
300            let signature = self.git_repo.signature()?;
301            let tree_id = index.write_tree()?;
302            let tree = self.git_repo.find_tree(tree_id)?;
303
304            self.git_repo.commit(
305                Some(&format!("refs/heads/{}", branch_name)),
306                &signature,
307                &signature,
308                &format!("Merge remote-tracking branch '{}'", tracking.remote_ref),
309                &tree,
310                &[local_commit, remote_commit],
311            )?;
312
313            self.git_repo.cleanup_state()?;
314
315            Ok(MergeResult::Merged)
316        } else {
317            // There are conflicts
318            Ok(MergeResult::Conflicts)
319        }
320    }
321
322    pub fn get_remote_name_for_branch(&self, branch_name: &str) -> Result<String> {
323        if let Some(tracking) = self.tracking_branch(branch_name)? {
324            Ok(tracking.remote)
325        } else {
326            // Fall back to origin if no tracking branch is configured
327            Ok("origin".to_string())
328        }
329    }
330
331    /// Get the ahead/behind count relative to the remote tracking branch
332    pub fn get_remote_comparison(
333        &self,
334        branch_name: &str,
335    ) -> Result<Option<RemoteComparison>> {
336        // Get the tracking branch info
337        let tracking = match self.tracking_branch(branch_name)? {
338            Some(tracking) => tracking,
339            None => return Ok(None), // No tracking branch configured
340        };
341
342        // Check if remote branch exists
343        let remote_oid = match self.git_repo.refname_to_id(&tracking.remote_ref) {
344            Ok(oid) => oid,
345            Err(_) => {
346                // Remote branch doesn't exist
347                return Ok(Some(RemoteComparison::NoRemote));
348            },
349        };
350
351        // Get local branch OID
352        let local_oid = self.git_repo.head()?.peel_to_commit()?.id();
353
354        // If they're the same, we're up to date
355        if local_oid == remote_oid {
356            return Ok(Some(RemoteComparison::UpToDate));
357        }
358
359        // Calculate ahead/behind using git's graph functions
360        let (ahead, behind) =
361            self.git_repo.graph_ahead_behind(local_oid, remote_oid)?;
362
363        if ahead > 0 && behind > 0 {
364            Ok(Some(RemoteComparison::Diverged(ahead, behind)))
365        } else if ahead > 0 {
366            Ok(Some(RemoteComparison::Ahead(ahead)))
367        } else if behind > 0 {
368            Ok(Some(RemoteComparison::Behind(behind)))
369        } else {
370            Ok(Some(RemoteComparison::UpToDate))
371        }
372    }
373
374    pub fn remote_callbacks(&self) -> Result<git2::RemoteCallbacks<'static>> {
375        self.remote_callbacks_impl(false)
376    }
377
378    pub fn remote_callbacks_verbose(&self) -> Result<git2::RemoteCallbacks<'static>> {
379        self.remote_callbacks_impl(true)
380    }
381
382    fn remote_callbacks_impl(
383        &self,
384        verbose: bool,
385    ) -> Result<git2::RemoteCallbacks<'static>> {
386        let config = self.git_repo.config()?;
387
388        let mut callbacks = git2::RemoteCallbacks::new();
389        callbacks.credentials(move |url, username_from_url, allowed| {
390            if verbose {
391                eprintln!("DEBUG: Credential callback invoked");
392                eprintln!("  URL: {}", url);
393                eprintln!("  Username from URL: {:?}", username_from_url);
394                eprintln!("  Allowed types: {:?}", allowed);
395            }
396
397            // Try SSH key from agent (only if SSH_AUTH_SOCK is set)
398            if allowed.contains(git2::CredentialType::SSH_KEY) {
399                if let Some(username) = username_from_url {
400                    // Check if SSH agent is actually available
401                    if std::env::var("SSH_AUTH_SOCK").is_ok() {
402                        if verbose {
403                            eprintln!(
404                                "  Attempting: SSH key from agent for user '{}'",
405                                username
406                            );
407                        }
408                        match git2::Cred::ssh_key_from_agent(username) {
409                            Ok(cred) => {
410                                if verbose {
411                                    eprintln!("  SUCCESS: SSH key from agent");
412                                }
413                                return Ok(cred);
414                            },
415                            Err(e) => {
416                                if verbose {
417                                    eprintln!("  FAILED: SSH key from agent - {}", e);
418                                }
419                            },
420                        }
421                    } else if verbose {
422                        eprintln!(
423                            "  SKIPPED: SSH key from agent (SSH_AUTH_SOCK not set)"
424                        );
425                    }
426                } else if verbose {
427                    eprintln!("  SKIPPED: SSH key from agent (no username provided)");
428                }
429
430                // Try SSH key files directly
431                if let Some(username) = username_from_url
432                    && let Ok(home) = std::env::var("HOME")
433                {
434                    let key_paths = vec![
435                        format!("{}/.ssh/id_ed25519", home),
436                        format!("{}/.ssh/id_rsa", home),
437                        format!("{}/.ssh/id_ecdsa", home),
438                    ];
439
440                    for key_path in key_paths {
441                        if path::Path::new(&key_path).exists() {
442                            if verbose {
443                                eprintln!("  Attempting: SSH key file at {}", key_path);
444                            }
445                            match git2::Cred::ssh_key(
446                                username,
447                                None, // no public key path
448                                path::Path::new(&key_path),
449                                None, // no passphrase
450                            ) {
451                                Ok(cred) => {
452                                    if verbose {
453                                        eprintln!("  SUCCESS: SSH key file");
454                                    }
455                                    return Ok(cred);
456                                },
457                                Err(e) => {
458                                    if verbose {
459                                        eprintln!("  FAILED: SSH key file - {}", e);
460                                    }
461                                },
462                            }
463                        }
464                    }
465                }
466            }
467
468            // Try credential helper
469            if allowed.contains(git2::CredentialType::USER_PASS_PLAINTEXT)
470                || allowed.contains(git2::CredentialType::SSH_KEY)
471                || allowed.contains(git2::CredentialType::DEFAULT)
472            {
473                if verbose {
474                    eprintln!("  Attempting: Credential helper");
475                }
476                match git2::Cred::credential_helper(&config, url, username_from_url) {
477                    Ok(cred) => {
478                        if verbose {
479                            eprintln!("  SUCCESS: Credential helper");
480                        }
481                        return Ok(cred);
482                    },
483                    Err(e) => {
484                        if verbose {
485                            eprintln!("  FAILED: Credential helper - {}", e);
486                        }
487                    },
488                }
489            }
490
491            // Try username only
492            if allowed.contains(git2::CredentialType::USERNAME) {
493                let username = username_from_url.unwrap_or("git");
494                if verbose {
495                    eprintln!("  Attempting: Username only ('{}')", username);
496                }
497                match git2::Cred::username(username) {
498                    Ok(cred) => {
499                        if verbose {
500                            eprintln!("  SUCCESS: Username");
501                        }
502                        return Ok(cred);
503                    },
504                    Err(e) => {
505                        if verbose {
506                            eprintln!("  FAILED: Username - {}", e);
507                        }
508                    },
509                }
510            }
511
512            // Try default
513            if verbose {
514                eprintln!("  Attempting: Default credentials");
515            }
516            match git2::Cred::default() {
517                Ok(cred) => {
518                    if verbose {
519                        eprintln!("  SUCCESS: Default credentials");
520                    }
521                    Ok(cred)
522                },
523                Err(e) => {
524                    if verbose {
525                        eprintln!("  FAILED: All credential methods exhausted");
526                        eprintln!("  Last error: {}", e);
527                    }
528                    Err(e)
529                },
530            }
531        });
532
533        Ok(callbacks)
534    }
535
536    fn resolve_reference(&self, short_name: &str) -> Result<String> {
537        Ok(self
538            .git_repo
539            .resolve_reference_from_short_name(short_name)?
540            .name()
541            .with_context(|| {
542                format!(
543                    "Cannot resolve head reference for repo at `{}`",
544                    self.work_dir.display()
545                )
546            })?
547            .to_owned())
548    }
549
550    pub fn tracking_branch(&self, branch_name: &str) -> Result<Option<TrackingBranch>> {
551        let config = self.git_repo.config()?;
552
553        let remote_key = format!("branch.{}.remote", branch_name);
554        let merge_key = format!("branch.{}.merge", branch_name);
555
556        let remote = match config.get_string(&remote_key) {
557            Ok(name) => name,
558            Err(err) if err.code() == git2::ErrorCode::NotFound => return Ok(None),
559            Err(err) => return Err(err.into()),
560        };
561
562        let merge_ref = match config.get_string(&merge_key) {
563            Ok(name) => name,
564            Err(err) if err.code() == git2::ErrorCode::NotFound => return Ok(None),
565            Err(err) => return Err(err.into()),
566        };
567
568        let branch_short = merge_ref
569            .strip_prefix("refs/heads/")
570            .unwrap_or(&merge_ref)
571            .to_owned();
572
573        let remote_ref = format!("refs/remotes/{}/{}", remote, branch_short);
574
575        Ok(Some(TrackingBranch { remote, remote_ref }))
576    }
577
578    fn get_pull_strategy(&self, branch_name: &str) -> Result<PullStrategy> {
579        let config = self.git_repo.config()?;
580
581        // First check branch-specific rebase setting (highest priority)
582        let branch_rebase_key = format!("branch.{}.rebase", branch_name);
583        if let Ok(value) = config.get_string(&branch_rebase_key) {
584            return Ok(parse_rebase_config(&value));
585        }
586
587        // Then check global pull.rebase setting
588        if let Ok(value) = config.get_string("pull.rebase") {
589            return Ok(parse_rebase_config(&value));
590        }
591
592        // Try as boolean for backward compatibility
593        if let Ok(value) = config.get_bool("pull.rebase") {
594            return Ok(if value {
595                PullStrategy::Rebase
596            } else {
597                PullStrategy::Merge
598            });
599        }
600
601        // Default to merge
602        Ok(PullStrategy::Merge)
603    }
604}
605
606impl fmt::Debug for Repo {
607    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
608        f.debug_struct("Repo")
609            .field("work_dir", &self.work_dir)
610            .field("head", &self.head)
611            .field("subrepos", &self.subrepos)
612            .finish()
613    }
614}
615
/// Upstream configuration of a local branch.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TrackingBranch {
    /// Remote name read from `branch.<name>.remote`.
    pub remote: String,
    /// Full name of the ref holding the upstream tip, e.g. `refs/remotes/origin/main`.
    pub remote_ref: String,
}
620
/// How upstream changes are combined with local commits when pulling.
#[derive(Clone, Debug, PartialEq)]
enum PullStrategy {
    /// Create a merge commit joining both histories.
    Merge,
    /// Replay local commits on top of the upstream.
    Rebase,
}
626
627fn parse_rebase_config(value: &str) -> PullStrategy {
628    match value.to_lowercase().as_str() {
629        "true" | "interactive" | "i" | "merges" | "m" => PullStrategy::Rebase,
630        "false" => PullStrategy::Merge,
631        _ => PullStrategy::Merge, // Default to merge for unknown values
632    }
633}