git_wok/
repo.rs

1use std::{fmt, path};
2
3use anyhow::*;
4use git2::build::CheckoutBuilder;
5use std::result::Result::Ok;
6
/// Outcome reported by [`Repo::merge`] after trying to integrate the upstream
/// branch into the local one.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MergeResult {
    /// Nothing had to be integrated (same commit, no upstream configured, or
    /// the remote-tracking ref does not exist).
    UpToDate,
    /// The local branch was moved forward to the upstream commit.
    FastForward,
    /// A merge commit joining the local and upstream commits was created.
    Merged,
    /// The local commits were replayed on top of the upstream commit.
    Rebased,
    /// The integration stopped because conflicts were detected.
    Conflicts,
}
15
/// Position of the local `HEAD` commit relative to its remote-tracking ref,
/// as computed by [`Repo::get_remote_comparison`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RemoteComparison {
    /// Both sides point at the same commit (or neither side has extra commits).
    UpToDate,
    /// The local side has this many commits the remote side lacks.
    Ahead(usize),
    /// The remote side has this many commits the local side lacks.
    Behind(usize),
    /// Both sides have unique commits: `(ahead, behind)`.
    Diverged(usize, usize),
    /// A tracking branch is configured but its remote-tracking ref is missing.
    NoRemote,
}
24
25pub struct Repo {
26    pub git_repo: git2::Repository,
27    pub work_dir: path::PathBuf,
28    pub head: String,
29    pub subrepos: Vec<Repo>,
30}
31
32impl Repo {
33    pub fn new(work_dir: &path::Path, head_name: Option<&str>) -> Result<Self> {
34        let git_repo = git2::Repository::open(work_dir)
35            .with_context(|| format!("Cannot open repo at `{}`", work_dir.display()))?;
36
37        let head = match head_name {
38            Some(name) => String::from(name),
39            None => {
40                if git_repo.head_detached().with_context(|| {
41                    format!(
42                        "Cannot determine head state for repo at `{}`",
43                        work_dir.display()
44                    )
45                })? {
46                    bail!(
47                        "Cannot operate on a detached head for repo at `{}`",
48                        work_dir.display()
49                    )
50                }
51
52                String::from(git_repo.head().with_context(|| {
53                    format!(
54                        "Cannot find the head branch for repo at `{}`. Is it detached?",
55                        work_dir.display()
56                    )
57                })?.shorthand().with_context(|| {
58                    format!(
59                        "Cannot find a human readable representation of the head ref for repo at `{}`",
60                        work_dir.display(),
61                    )
62                })?)
63            },
64        };
65
66        let subrepos = git_repo
67            .submodules()
68            .with_context(|| {
69                format!(
70                    "Cannot load submodules for repo at `{}`",
71                    work_dir.display()
72                )
73            })?
74            .iter()
75            .map(|submodule| Repo::new(&work_dir.join(submodule.path()), Some(&head)))
76            .collect::<Result<Vec<Repo>>>()?;
77
78        Ok(Repo {
79            git_repo,
80            work_dir: path::PathBuf::from(work_dir),
81            head,
82            subrepos,
83        })
84    }
85
86    pub fn get_subrepo_by_path(&self, subrepo_path: &path::PathBuf) -> Option<&Repo> {
87        self.subrepos
88            .iter()
89            .find(|subrepo| subrepo.work_dir == self.work_dir.join(subrepo_path))
90    }
91
92    pub fn sync(&self) -> Result<()> {
93        self.switch(&self.head)?;
94        Ok(())
95    }
96
97    pub fn switch(&self, head: &str) -> Result<()> {
98        self.git_repo.set_head(&self.resolve_reference(head)?)?;
99        self.git_repo.checkout_head(None)?;
100        Ok(())
101    }
102
103    pub fn fetch(&self) -> Result<()> {
104        // Get the remote for the current branch
105        let head_ref = self.git_repo.head()?;
106        let branch_name = head_ref.shorthand().with_context(|| {
107            format!(
108                "Cannot get branch name for repo at `{}`",
109                self.work_dir.display()
110            )
111        })?;
112
113        let tracking = match self.tracking_branch(branch_name)? {
114            Some(tracking) => tracking,
115            None => {
116                // No upstream configured, skip fetch
117                return Ok(());
118            },
119        };
120
121        // Check if remote exists
122        match self.git_repo.find_remote(&tracking.remote) {
123            Ok(mut remote) => {
124                let mut fetch_options = git2::FetchOptions::new();
125                fetch_options.remote_callbacks(self.remote_callbacks()?);
126
127                remote
128                    .fetch::<&str>(&[], Some(&mut fetch_options), None)
129                    .with_context(|| {
130                        format!(
131                            "Failed to fetch from remote '{}' for repo at `{}`\n\
132                            \n\
133                            Possible causes:\n\
134                            - SSH agent not running or not accessible (check SSH_AUTH_SOCK)\n\
135                            - SSH keys not properly configured in ~/.ssh/\n\
136                            - Credential helper not configured (git config credential.helper)\n\
137                            - Network/firewall issues\n\
138                            \n\
139                            Try running: git fetch --verbose\n\
140                            Or check authentication with: git-wok test-auth",
141                            tracking.remote,
142                            self.work_dir.display()
143                        )
144                    })?;
145            },
146            Err(_) => {
147                // No remote configured, skip fetch
148                return Ok(());
149            },
150        }
151
152        Ok(())
153    }
154
155    fn rebase(
156        &self,
157        _branch_name: &str,
158        remote_commit: &git2::Commit,
159    ) -> Result<MergeResult> {
160        let _local_commit = self.git_repo.head()?.peel_to_commit()?;
161        let remote_oid = remote_commit.id();
162
163        // Prepare annotated commit for rebase
164        let remote_annotated = self.git_repo.find_annotated_commit(remote_oid)?;
165
166        // Initialize rebase operation
167        let signature = self.git_repo.signature()?;
168        let mut rebase = self.git_repo.rebase(
169            None,                    // branch to rebase (None = HEAD)
170            Some(&remote_annotated), // upstream
171            None,                    // onto (None = upstream)
172            None,                    // options
173        )?;
174
175        // Process each commit in the rebase
176        let mut has_conflicts = false;
177        while let Some(op) = rebase.next() {
178            match op {
179                Ok(_rebase_op) => {
180                    // Check for conflicts
181                    let index = self.git_repo.index()?;
182                    if index.has_conflicts() {
183                        has_conflicts = true;
184                        break;
185                    }
186
187                    // Commit the rebased changes
188                    if rebase.commit(None, &signature, None).is_err() {
189                        has_conflicts = true;
190                        break;
191                    }
192                },
193                Err(_) => {
194                    has_conflicts = true;
195                    break;
196                },
197            }
198        }
199
200        if has_conflicts {
201            // Leave repository in state with conflicts for user to resolve
202            return Ok(MergeResult::Conflicts);
203        }
204
205        // Finish the rebase
206        rebase.finish(Some(&signature))?;
207
208        Ok(MergeResult::Rebased)
209    }
210
211    pub fn merge(&self, branch_name: &str) -> Result<MergeResult> {
212        // First, fetch the latest changes
213        self.fetch()?;
214
215        // Resolve the tracking branch reference
216        let tracking = match self.tracking_branch(branch_name)? {
217            Some(tracking) => tracking,
218            None => {
219                // No upstream configured, treat as up to date
220                return Ok(MergeResult::UpToDate);
221            },
222        };
223
224        // Check if remote branch exists
225        let remote_branch_oid = match self.git_repo.refname_to_id(&tracking.remote_ref)
226        {
227            Ok(oid) => oid,
228            Err(_) => {
229                // No remote branch, just return up to date
230                return Ok(MergeResult::UpToDate);
231            },
232        };
233
234        let remote_commit = self.git_repo.find_commit(remote_branch_oid)?;
235        let local_commit = self.git_repo.head()?.peel_to_commit()?;
236
237        // Check if we're already up to date
238        if local_commit.id() == remote_commit.id() {
239            return Ok(MergeResult::UpToDate);
240        }
241
242        // Check if we can fast-forward (works for both merge and rebase)
243        if self
244            .git_repo
245            .graph_descendant_of(remote_commit.id(), local_commit.id())?
246        {
247            // Fast-forward merge
248            self.git_repo.reference(
249                &format!("refs/heads/{}", branch_name),
250                remote_commit.id(),
251                true,
252                &format!("Fast-forward '{}' to {}", branch_name, tracking.remote_ref),
253            )?;
254            self.git_repo
255                .set_head(&format!("refs/heads/{}", branch_name))?;
256            let mut checkout = CheckoutBuilder::new();
257            checkout.force();
258            self.git_repo.checkout_head(Some(&mut checkout))?;
259            return Ok(MergeResult::FastForward);
260        }
261
262        // Determine pull strategy from git config
263        let pull_strategy = self.get_pull_strategy(branch_name)?;
264
265        match pull_strategy {
266            PullStrategy::Rebase => {
267                // Perform rebase
268                self.rebase(branch_name, &remote_commit)
269            },
270            PullStrategy::Merge => {
271                // Perform merge (existing logic)
272                self.do_merge(branch_name, &local_commit, &remote_commit, &tracking)
273            },
274        }
275    }
276
277    fn do_merge(
278        &self,
279        branch_name: &str,
280        local_commit: &git2::Commit,
281        remote_commit: &git2::Commit,
282        tracking: &TrackingBranch,
283    ) -> Result<MergeResult> {
284        // Perform a merge
285        let mut merge_opts = git2::MergeOptions::new();
286        merge_opts.fail_on_conflict(false); // Don't fail on conflicts, we'll handle them
287
288        let _merge_result = self.git_repo.merge_commits(
289            local_commit,
290            remote_commit,
291            Some(&merge_opts),
292        )?;
293
294        // Check if there are conflicts by examining the index
295        let mut index = self.git_repo.index()?;
296        let has_conflicts = index.has_conflicts();
297
298        if !has_conflicts {
299            // No conflicts, merge was successful
300            let signature = self.git_repo.signature()?;
301            let tree_id = index.write_tree()?;
302            let tree = self.git_repo.find_tree(tree_id)?;
303
304            self.git_repo.commit(
305                Some(&format!("refs/heads/{}", branch_name)),
306                &signature,
307                &signature,
308                &format!("Merge remote-tracking branch '{}'", tracking.remote_ref),
309                &tree,
310                &[local_commit, remote_commit],
311            )?;
312
313            self.git_repo.cleanup_state()?;
314
315            Ok(MergeResult::Merged)
316        } else {
317            // There are conflicts
318            Ok(MergeResult::Conflicts)
319        }
320    }
321
322    pub fn get_remote_name_for_branch(&self, branch_name: &str) -> Result<String> {
323        if let Some(tracking) = self.tracking_branch(branch_name)? {
324            Ok(tracking.remote)
325        } else {
326            // Fall back to origin if no tracking branch is configured
327            Ok("origin".to_string())
328        }
329    }
330
331    /// Get the ahead/behind count relative to the remote tracking branch
332    pub fn get_remote_comparison(
333        &self,
334        branch_name: &str,
335    ) -> Result<Option<RemoteComparison>> {
336        // Get the tracking branch info
337        let tracking = match self.tracking_branch(branch_name)? {
338            Some(tracking) => tracking,
339            None => return Ok(None), // No tracking branch configured
340        };
341
342        // Check if remote branch exists
343        let remote_oid = match self.git_repo.refname_to_id(&tracking.remote_ref) {
344            Ok(oid) => oid,
345            Err(_) => {
346                // Remote branch doesn't exist
347                return Ok(Some(RemoteComparison::NoRemote));
348            },
349        };
350
351        // Get local branch OID
352        let local_oid = self.git_repo.head()?.peel_to_commit()?.id();
353
354        // If they're the same, we're up to date
355        if local_oid == remote_oid {
356            return Ok(Some(RemoteComparison::UpToDate));
357        }
358
359        // Calculate ahead/behind using git's graph functions
360        let (ahead, behind) =
361            self.git_repo.graph_ahead_behind(local_oid, remote_oid)?;
362
363        if ahead > 0 && behind > 0 {
364            Ok(Some(RemoteComparison::Diverged(ahead, behind)))
365        } else if ahead > 0 {
366            Ok(Some(RemoteComparison::Ahead(ahead)))
367        } else if behind > 0 {
368            Ok(Some(RemoteComparison::Behind(behind)))
369        } else {
370            Ok(Some(RemoteComparison::UpToDate))
371        }
372    }
373
374    pub fn remote_callbacks(&self) -> Result<git2::RemoteCallbacks<'static>> {
375        self.remote_callbacks_impl(false)
376    }
377
378    pub fn remote_callbacks_verbose(&self) -> Result<git2::RemoteCallbacks<'static>> {
379        self.remote_callbacks_impl(true)
380    }
381
382    fn remote_callbacks_impl(&self, verbose: bool) -> Result<git2::RemoteCallbacks<'static>> {
383        let config = self.git_repo.config()?;
384
385        let mut callbacks = git2::RemoteCallbacks::new();
386        callbacks.credentials(move |url, username_from_url, allowed| {
387            if verbose {
388                eprintln!("DEBUG: Credential callback invoked");
389                eprintln!("  URL: {}", url);
390                eprintln!("  Username from URL: {:?}", username_from_url);
391                eprintln!("  Allowed types: {:?}", allowed);
392            }
393
394            // Try SSH key from agent (only if SSH_AUTH_SOCK is set)
395            if allowed.contains(git2::CredentialType::SSH_KEY) {
396                if let Some(username) = username_from_url {
397                    // Check if SSH agent is actually available
398                    if std::env::var("SSH_AUTH_SOCK").is_ok() {
399                        if verbose {
400                            eprintln!("  Attempting: SSH key from agent for user '{}'", username);
401                        }
402                        match git2::Cred::ssh_key_from_agent(username) {
403                            Ok(cred) => {
404                                if verbose {
405                                    eprintln!("  SUCCESS: SSH key from agent");
406                                }
407                                return Ok(cred);
408                            }
409                            Err(e) => {
410                                if verbose {
411                                    eprintln!("  FAILED: SSH key from agent - {}", e);
412                                }
413                            }
414                        }
415                    } else if verbose {
416                        eprintln!("  SKIPPED: SSH key from agent (SSH_AUTH_SOCK not set)");
417                    }
418                } else if verbose {
419                    eprintln!("  SKIPPED: SSH key from agent (no username provided)");
420                }
421
422                // Try SSH key files directly
423                if let Some(username) = username_from_url
424                    && let Ok(home) = std::env::var("HOME")
425                {
426                    let key_paths = vec![
427                        format!("{}/.ssh/id_ed25519", home),
428                        format!("{}/.ssh/id_rsa", home),
429                        format!("{}/.ssh/id_ecdsa", home),
430                    ];
431
432                    for key_path in key_paths {
433                        if path::Path::new(&key_path).exists() {
434                            if verbose {
435                                eprintln!("  Attempting: SSH key file at {}", key_path);
436                            }
437                            match git2::Cred::ssh_key(
438                                username,
439                                None, // no public key path
440                                path::Path::new(&key_path),
441                                None, // no passphrase
442                            ) {
443                                Ok(cred) => {
444                                    if verbose {
445                                        eprintln!("  SUCCESS: SSH key file");
446                                    }
447                                    return Ok(cred);
448                                }
449                                Err(e) => {
450                                    if verbose {
451                                        eprintln!("  FAILED: SSH key file - {}", e);
452                                    }
453                                }
454                            }
455                        }
456                    }
457                }
458            }
459
460            // Try credential helper
461            if allowed.contains(git2::CredentialType::USER_PASS_PLAINTEXT)
462                || allowed.contains(git2::CredentialType::SSH_KEY)
463                || allowed.contains(git2::CredentialType::DEFAULT)
464            {
465                if verbose {
466                    eprintln!("  Attempting: Credential helper");
467                }
468                match git2::Cred::credential_helper(&config, url, username_from_url) {
469                    Ok(cred) => {
470                        if verbose {
471                            eprintln!("  SUCCESS: Credential helper");
472                        }
473                        return Ok(cred);
474                    }
475                    Err(e) => {
476                        if verbose {
477                            eprintln!("  FAILED: Credential helper - {}", e);
478                        }
479                    }
480                }
481            }
482
483            // Try username only
484            if allowed.contains(git2::CredentialType::USERNAME) {
485                let username = username_from_url.unwrap_or("git");
486                if verbose {
487                    eprintln!("  Attempting: Username only ('{}')", username);
488                }
489                match git2::Cred::username(username) {
490                    Ok(cred) => {
491                        if verbose {
492                            eprintln!("  SUCCESS: Username");
493                        }
494                        return Ok(cred);
495                    }
496                    Err(e) => {
497                        if verbose {
498                            eprintln!("  FAILED: Username - {}", e);
499                        }
500                    }
501                }
502            }
503
504            // Try default
505            if verbose {
506                eprintln!("  Attempting: Default credentials");
507            }
508            match git2::Cred::default() {
509                Ok(cred) => {
510                    if verbose {
511                        eprintln!("  SUCCESS: Default credentials");
512                    }
513                    Ok(cred)
514                }
515                Err(e) => {
516                    if verbose {
517                        eprintln!("  FAILED: All credential methods exhausted");
518                        eprintln!("  Last error: {}", e);
519                    }
520                    Err(e)
521                }
522            }
523        });
524
525        Ok(callbacks)
526    }
527
528    fn resolve_reference(&self, short_name: &str) -> Result<String> {
529        Ok(self
530            .git_repo
531            .resolve_reference_from_short_name(short_name)?
532            .name()
533            .with_context(|| {
534                format!(
535                    "Cannot resolve head reference for repo at `{}`",
536                    self.work_dir.display()
537                )
538            })?
539            .to_owned())
540    }
541
542    pub fn tracking_branch(&self, branch_name: &str) -> Result<Option<TrackingBranch>> {
543        let config = self.git_repo.config()?;
544
545        let remote_key = format!("branch.{}.remote", branch_name);
546        let merge_key = format!("branch.{}.merge", branch_name);
547
548        let remote = match config.get_string(&remote_key) {
549            Ok(name) => name,
550            Err(err) if err.code() == git2::ErrorCode::NotFound => return Ok(None),
551            Err(err) => return Err(err.into()),
552        };
553
554        let merge_ref = match config.get_string(&merge_key) {
555            Ok(name) => name,
556            Err(err) if err.code() == git2::ErrorCode::NotFound => return Ok(None),
557            Err(err) => return Err(err.into()),
558        };
559
560        let branch_short = merge_ref
561            .strip_prefix("refs/heads/")
562            .unwrap_or(&merge_ref)
563            .to_owned();
564
565        let remote_ref = format!("refs/remotes/{}/{}", remote, branch_short);
566
567        Ok(Some(TrackingBranch { remote, remote_ref }))
568    }
569
570    fn get_pull_strategy(&self, branch_name: &str) -> Result<PullStrategy> {
571        let config = self.git_repo.config()?;
572
573        // First check branch-specific rebase setting (highest priority)
574        let branch_rebase_key = format!("branch.{}.rebase", branch_name);
575        if let Ok(value) = config.get_string(&branch_rebase_key) {
576            return Ok(parse_rebase_config(&value));
577        }
578
579        // Then check global pull.rebase setting
580        if let Ok(value) = config.get_string("pull.rebase") {
581            return Ok(parse_rebase_config(&value));
582        }
583
584        // Try as boolean for backward compatibility
585        if let Ok(value) = config.get_bool("pull.rebase") {
586            return Ok(if value {
587                PullStrategy::Rebase
588            } else {
589                PullStrategy::Merge
590            });
591        }
592
593        // Default to merge
594        Ok(PullStrategy::Merge)
595    }
596}
597
598impl fmt::Debug for Repo {
599    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
600        f.debug_struct("Repo")
601            .field("work_dir", &self.work_dir)
602            .field("head", &self.head)
603            .field("subrepos", &self.subrepos)
604            .finish()
605    }
606}
607
/// Upstream configuration of a local branch, as produced by
/// [`Repo::tracking_branch`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TrackingBranch {
    /// Name of the remote, taken from `branch.<name>.remote`.
    pub remote: String,
    /// Full name of the remote-tracking ref, `refs/remotes/<remote>/<branch>`.
    pub remote_ref: String,
}
612
/// How diverged local and upstream histories are integrated.
#[derive(Debug, Clone, PartialEq)]
enum PullStrategy {
    /// Create a merge commit.
    Merge,
    /// Replay the local commits on top of the upstream commit.
    Rebase,
}

/// Maps the value of a `pull.rebase` / `branch.<name>.rebase` setting to a
/// [`PullStrategy`].
///
/// Matching is case-insensitive. The truthy spellings git accepts for
/// booleans (`true`, `yes`, `on`, `1`) and the rebase modes `interactive`/`i`
/// and `merges`/`m` select [`PullStrategy::Rebase`]. Everything else,
/// including the falsy spellings and unknown values, selects
/// [`PullStrategy::Merge`].
fn parse_rebase_config(value: &str) -> PullStrategy {
    match value.to_ascii_lowercase().as_str() {
        "true" | "yes" | "on" | "1" | "interactive" | "i" | "merges" | "m" => {
            PullStrategy::Rebase
        },
        // "false", "no", "off", "0" and anything unrecognised.
        _ => PullStrategy::Merge,
    }
}