git_wok/
repo.rs

1use std::{fmt, path};
2
3use anyhow::*;
4use git2::build::CheckoutBuilder;
5use std::result::Result::Ok;
6
/// Outcome of integrating a branch's upstream into the local branch.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MergeResult {
    /// Nothing had to be integrated.
    UpToDate,
    /// The local branch was moved forward to the upstream commit.
    FastForward,
    /// A merge commit combining local and upstream history was created.
    Merged,
    /// Local commits were replayed on top of the upstream commit.
    Rebased,
    /// Integration stopped on conflicts that the user must resolve.
    Conflicts,
}
15
/// Position of the local branch relative to its remote-tracking branch.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RemoteComparison {
    /// Local and remote-tracking branch point at the same commit.
    UpToDate,
    /// The local branch has this many commits the remote does not.
    Ahead(usize),
    /// The remote has this many commits the local branch does not.
    Behind(usize),
    /// Both sides have unique commits: `(ahead, behind)`.
    Diverged(usize, usize),
    /// An upstream is configured but its remote-tracking ref does not exist.
    NoRemote,
}
24
25pub struct Repo {
26    pub git_repo: git2::Repository,
27    pub work_dir: path::PathBuf,
28    pub head: String,
29    pub subrepos: Vec<Repo>,
30}
31
32impl Repo {
33    pub fn new(work_dir: &path::Path, head_name: Option<&str>) -> Result<Self> {
34        let git_repo = git2::Repository::open(work_dir)
35            .with_context(|| format!("Cannot open repo at `{}`", work_dir.display()))?;
36
37        let head = match head_name {
38            Some(name) => String::from(name),
39            None => {
40                if git_repo.head_detached().with_context(|| {
41                    format!(
42                        "Cannot determine head state for repo at `{}`",
43                        work_dir.display()
44                    )
45                })? {
46                    bail!(
47                        "Cannot operate on a detached head for repo at `{}`",
48                        work_dir.display()
49                    )
50                }
51
52                String::from(git_repo.head().with_context(|| {
53                    format!(
54                        "Cannot find the head branch for repo at `{}`. Is it detached?",
55                        work_dir.display()
56                    )
57                })?.shorthand().with_context(|| {
58                    format!(
59                        "Cannot find a human readable representation of the head ref for repo at `{}`",
60                        work_dir.display(),
61                    )
62                })?)
63            },
64        };
65
66        let subrepos = git_repo
67            .submodules()
68            .with_context(|| {
69                format!(
70                    "Cannot load submodules for repo at `{}`",
71                    work_dir.display()
72                )
73            })?
74            .iter()
75            .map(|submodule| Repo::new(&work_dir.join(submodule.path()), Some(&head)))
76            .collect::<Result<Vec<Repo>>>()?;
77
78        Ok(Repo {
79            git_repo,
80            work_dir: path::PathBuf::from(work_dir),
81            head,
82            subrepos,
83        })
84    }
85
86    pub fn get_subrepo_by_path(&self, subrepo_path: &path::PathBuf) -> Option<&Repo> {
87        self.subrepos
88            .iter()
89            .find(|subrepo| subrepo.work_dir == self.work_dir.join(subrepo_path))
90    }
91
92    pub fn sync(&self) -> Result<()> {
93        self.switch(&self.head)?;
94        Ok(())
95    }
96
97    pub fn switch(&self, head: &str) -> Result<()> {
98        self.git_repo.set_head(&self.resolve_reference(head)?)?;
99        self.git_repo.checkout_head(None)?;
100        Ok(())
101    }
102
103    pub fn fetch(&self) -> Result<()> {
104        // Get the remote for the current branch
105        let head_ref = self.git_repo.head()?;
106        let branch_name = head_ref.shorthand().with_context(|| {
107            format!(
108                "Cannot get branch name for repo at `{}`",
109                self.work_dir.display()
110            )
111        })?;
112
113        let tracking = match self.tracking_branch(branch_name)? {
114            Some(tracking) => tracking,
115            None => {
116                // No upstream configured, skip fetch
117                return Ok(());
118            },
119        };
120
121        // Check if remote exists
122        match self.git_repo.find_remote(&tracking.remote) {
123            Ok(mut remote) => {
124                let mut fetch_options = git2::FetchOptions::new();
125                fetch_options.remote_callbacks(self.remote_callbacks()?);
126
127                remote
128                    .fetch::<&str>(&[], Some(&mut fetch_options), None)
129                    .with_context(|| {
130                        format!(
131                            "Failed to fetch from remote '{}' for repo at `{}`",
132                            tracking.remote,
133                            self.work_dir.display()
134                        )
135                    })?;
136            },
137            Err(_) => {
138                // No remote configured, skip fetch
139                return Ok(());
140            },
141        }
142
143        Ok(())
144    }
145
146    fn rebase(
147        &self,
148        _branch_name: &str,
149        remote_commit: &git2::Commit,
150    ) -> Result<MergeResult> {
151        let _local_commit = self.git_repo.head()?.peel_to_commit()?;
152        let remote_oid = remote_commit.id();
153
154        // Prepare annotated commit for rebase
155        let remote_annotated = self.git_repo.find_annotated_commit(remote_oid)?;
156
157        // Initialize rebase operation
158        let signature = self.git_repo.signature()?;
159        let mut rebase = self.git_repo.rebase(
160            None,                    // branch to rebase (None = HEAD)
161            Some(&remote_annotated), // upstream
162            None,                    // onto (None = upstream)
163            None,                    // options
164        )?;
165
166        // Process each commit in the rebase
167        let mut has_conflicts = false;
168        while let Some(op) = rebase.next() {
169            match op {
170                Ok(_rebase_op) => {
171                    // Check for conflicts
172                    let index = self.git_repo.index()?;
173                    if index.has_conflicts() {
174                        has_conflicts = true;
175                        break;
176                    }
177
178                    // Commit the rebased changes
179                    if rebase.commit(None, &signature, None).is_err() {
180                        has_conflicts = true;
181                        break;
182                    }
183                },
184                Err(_) => {
185                    has_conflicts = true;
186                    break;
187                },
188            }
189        }
190
191        if has_conflicts {
192            // Leave repository in state with conflicts for user to resolve
193            return Ok(MergeResult::Conflicts);
194        }
195
196        // Finish the rebase
197        rebase.finish(Some(&signature))?;
198
199        Ok(MergeResult::Rebased)
200    }
201
202    pub fn merge(&self, branch_name: &str) -> Result<MergeResult> {
203        // First, fetch the latest changes
204        self.fetch()?;
205
206        // Resolve the tracking branch reference
207        let tracking = match self.tracking_branch(branch_name)? {
208            Some(tracking) => tracking,
209            None => {
210                // No upstream configured, treat as up to date
211                return Ok(MergeResult::UpToDate);
212            },
213        };
214
215        // Check if remote branch exists
216        let remote_branch_oid = match self.git_repo.refname_to_id(&tracking.remote_ref)
217        {
218            Ok(oid) => oid,
219            Err(_) => {
220                // No remote branch, just return up to date
221                return Ok(MergeResult::UpToDate);
222            },
223        };
224
225        let remote_commit = self.git_repo.find_commit(remote_branch_oid)?;
226        let local_commit = self.git_repo.head()?.peel_to_commit()?;
227
228        // Check if we're already up to date
229        if local_commit.id() == remote_commit.id() {
230            return Ok(MergeResult::UpToDate);
231        }
232
233        // Check if we can fast-forward (works for both merge and rebase)
234        if self
235            .git_repo
236            .graph_descendant_of(remote_commit.id(), local_commit.id())?
237        {
238            // Fast-forward merge
239            self.git_repo.reference(
240                &format!("refs/heads/{}", branch_name),
241                remote_commit.id(),
242                true,
243                &format!("Fast-forward '{}' to {}", branch_name, tracking.remote_ref),
244            )?;
245            self.git_repo
246                .set_head(&format!("refs/heads/{}", branch_name))?;
247            let mut checkout = CheckoutBuilder::new();
248            checkout.force();
249            self.git_repo.checkout_head(Some(&mut checkout))?;
250            return Ok(MergeResult::FastForward);
251        }
252
253        // Determine pull strategy from git config
254        let pull_strategy = self.get_pull_strategy(branch_name)?;
255
256        match pull_strategy {
257            PullStrategy::Rebase => {
258                // Perform rebase
259                self.rebase(branch_name, &remote_commit)
260            },
261            PullStrategy::Merge => {
262                // Perform merge (existing logic)
263                self.do_merge(branch_name, &local_commit, &remote_commit, &tracking)
264            },
265        }
266    }
267
268    fn do_merge(
269        &self,
270        branch_name: &str,
271        local_commit: &git2::Commit,
272        remote_commit: &git2::Commit,
273        tracking: &TrackingBranch,
274    ) -> Result<MergeResult> {
275        // Perform a merge
276        let mut merge_opts = git2::MergeOptions::new();
277        merge_opts.fail_on_conflict(false); // Don't fail on conflicts, we'll handle them
278
279        let _merge_result = self.git_repo.merge_commits(
280            local_commit,
281            remote_commit,
282            Some(&merge_opts),
283        )?;
284
285        // Check if there are conflicts by examining the index
286        let mut index = self.git_repo.index()?;
287        let has_conflicts = index.has_conflicts();
288
289        if !has_conflicts {
290            // No conflicts, merge was successful
291            let signature = self.git_repo.signature()?;
292            let tree_id = index.write_tree()?;
293            let tree = self.git_repo.find_tree(tree_id)?;
294
295            self.git_repo.commit(
296                Some(&format!("refs/heads/{}", branch_name)),
297                &signature,
298                &signature,
299                &format!("Merge remote-tracking branch '{}'", tracking.remote_ref),
300                &tree,
301                &[local_commit, remote_commit],
302            )?;
303
304            self.git_repo.cleanup_state()?;
305
306            Ok(MergeResult::Merged)
307        } else {
308            // There are conflicts
309            Ok(MergeResult::Conflicts)
310        }
311    }
312
313    pub fn get_remote_name_for_branch(&self, branch_name: &str) -> Result<String> {
314        if let Some(tracking) = self.tracking_branch(branch_name)? {
315            Ok(tracking.remote)
316        } else {
317            // Fall back to origin if no tracking branch is configured
318            Ok("origin".to_string())
319        }
320    }
321
322    /// Get the ahead/behind count relative to the remote tracking branch
323    pub fn get_remote_comparison(
324        &self,
325        branch_name: &str,
326    ) -> Result<Option<RemoteComparison>> {
327        // Get the tracking branch info
328        let tracking = match self.tracking_branch(branch_name)? {
329            Some(tracking) => tracking,
330            None => return Ok(None), // No tracking branch configured
331        };
332
333        // Check if remote branch exists
334        let remote_oid = match self.git_repo.refname_to_id(&tracking.remote_ref) {
335            Ok(oid) => oid,
336            Err(_) => {
337                // Remote branch doesn't exist
338                return Ok(Some(RemoteComparison::NoRemote));
339            },
340        };
341
342        // Get local branch OID
343        let local_oid = self.git_repo.head()?.peel_to_commit()?.id();
344
345        // If they're the same, we're up to date
346        if local_oid == remote_oid {
347            return Ok(Some(RemoteComparison::UpToDate));
348        }
349
350        // Calculate ahead/behind using git's graph functions
351        let (ahead, behind) =
352            self.git_repo.graph_ahead_behind(local_oid, remote_oid)?;
353
354        if ahead > 0 && behind > 0 {
355            Ok(Some(RemoteComparison::Diverged(ahead, behind)))
356        } else if ahead > 0 {
357            Ok(Some(RemoteComparison::Ahead(ahead)))
358        } else if behind > 0 {
359            Ok(Some(RemoteComparison::Behind(behind)))
360        } else {
361            Ok(Some(RemoteComparison::UpToDate))
362        }
363    }
364
365    pub fn remote_callbacks(&self) -> Result<git2::RemoteCallbacks<'static>> {
366        let config = self.git_repo.config()?;
367
368        let mut callbacks = git2::RemoteCallbacks::new();
369        callbacks.credentials(move |url, username_from_url, allowed| {
370            if allowed.contains(git2::CredentialType::SSH_KEY)
371                && let Some(username) = username_from_url
372                && let Ok(cred) = git2::Cred::ssh_key_from_agent(username)
373            {
374                return Ok(cred);
375            }
376
377            if (allowed.contains(git2::CredentialType::USER_PASS_PLAINTEXT)
378                || allowed.contains(git2::CredentialType::SSH_KEY)
379                || allowed.contains(git2::CredentialType::DEFAULT))
380                && let Ok(cred) =
381                    git2::Cred::credential_helper(&config, url, username_from_url)
382            {
383                return Ok(cred);
384            }
385
386            if allowed.contains(git2::CredentialType::USERNAME) {
387                if let Some(username) = username_from_url {
388                    return git2::Cred::username(username);
389                } else {
390                    return git2::Cred::username("git");
391                }
392            }
393
394            git2::Cred::default()
395        });
396
397        Ok(callbacks)
398    }
399
400    fn resolve_reference(&self, short_name: &str) -> Result<String> {
401        Ok(self
402            .git_repo
403            .resolve_reference_from_short_name(short_name)?
404            .name()
405            .with_context(|| {
406                format!(
407                    "Cannot resolve head reference for repo at `{}`",
408                    self.work_dir.display()
409                )
410            })?
411            .to_owned())
412    }
413
414    pub fn tracking_branch(&self, branch_name: &str) -> Result<Option<TrackingBranch>> {
415        let config = self.git_repo.config()?;
416
417        let remote_key = format!("branch.{}.remote", branch_name);
418        let merge_key = format!("branch.{}.merge", branch_name);
419
420        let remote = match config.get_string(&remote_key) {
421            Ok(name) => name,
422            Err(err) if err.code() == git2::ErrorCode::NotFound => return Ok(None),
423            Err(err) => return Err(err.into()),
424        };
425
426        let merge_ref = match config.get_string(&merge_key) {
427            Ok(name) => name,
428            Err(err) if err.code() == git2::ErrorCode::NotFound => return Ok(None),
429            Err(err) => return Err(err.into()),
430        };
431
432        let branch_short = merge_ref
433            .strip_prefix("refs/heads/")
434            .unwrap_or(&merge_ref)
435            .to_owned();
436
437        let remote_ref = format!("refs/remotes/{}/{}", remote, branch_short);
438
439        Ok(Some(TrackingBranch { remote, remote_ref }))
440    }
441
442    fn get_pull_strategy(&self, branch_name: &str) -> Result<PullStrategy> {
443        let config = self.git_repo.config()?;
444
445        // First check branch-specific rebase setting (highest priority)
446        let branch_rebase_key = format!("branch.{}.rebase", branch_name);
447        if let Ok(value) = config.get_string(&branch_rebase_key) {
448            return Ok(parse_rebase_config(&value));
449        }
450
451        // Then check global pull.rebase setting
452        if let Ok(value) = config.get_string("pull.rebase") {
453            return Ok(parse_rebase_config(&value));
454        }
455
456        // Try as boolean for backward compatibility
457        if let Ok(value) = config.get_bool("pull.rebase") {
458            return Ok(if value {
459                PullStrategy::Rebase
460            } else {
461                PullStrategy::Merge
462            });
463        }
464
465        // Default to merge
466        Ok(PullStrategy::Merge)
467    }
468}
469
470impl fmt::Debug for Repo {
471    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
472        f.debug_struct("Repo")
473            .field("work_dir", &self.work_dir)
474            .field("head", &self.head)
475            .field("subrepos", &self.subrepos)
476            .finish()
477    }
478}
479
/// Upstream configuration of a local branch.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TrackingBranch {
    /// Name of the remote, from `branch.<name>.remote`.
    pub remote: String,
    /// Full name of the ref holding the upstream commit, e.g. `refs/remotes/origin/main`.
    pub remote_ref: String,
}
484
/// How upstream changes are integrated into a diverged local branch.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PullStrategy {
    Merge,
    Rebase,
}

/// Parses a `pull.rebase` / `branch.<name>.rebase` value, case-insensitively.
///
/// Git accepts its boolean spellings here (`true`/`yes`/`on`/`1` and
/// `false`/`no`/`off`/`0`/empty) as well as `merges`/`m` and `interactive`/`i`.
/// All non-false rebase modes map to [`PullStrategy::Rebase`]; interactive rebasing
/// is not supported here. Unknown values fall back to [`PullStrategy::Merge`].
fn parse_rebase_config(value: &str) -> PullStrategy {
    match value.trim().to_lowercase().as_str() {
        "true" | "yes" | "on" | "1" | "interactive" | "i" | "merges" | "m" => {
            PullStrategy::Rebase
        },
        "false" | "no" | "off" | "0" | "" => PullStrategy::Merge,
        _ => PullStrategy::Merge, // Default to merge for unknown values
    }
}