rung_git/
repository.rs

1//! Repository wrapper providing high-level git operations.
2
3use std::path::Path;
4
5use git2::{BranchType, Oid, RepositoryState, Signature};
6
7use crate::error::{Error, Result};
8
9/// High-level wrapper around a git repository.
10pub struct Repository {
11    inner: git2::Repository,
12}
13
14impl Repository {
15    /// Open a repository at the given path.
16    ///
17    /// # Errors
18    /// Returns error if no repository found at path or any parent.
19    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
20        let inner = git2::Repository::discover(path)?;
21        Ok(Self { inner })
22    }
23
24    /// Open the repository containing the current directory.
25    ///
26    /// # Errors
27    /// Returns error if not inside a git repository.
28    pub fn open_current() -> Result<Self> {
29        Self::open(".")
30    }
31
32    /// Get the path to the repository root (workdir).
33    #[must_use]
34    pub fn workdir(&self) -> Option<&Path> {
35        self.inner.workdir()
36    }
37
38    /// Get the path to the .git directory.
39    #[must_use]
40    pub fn git_dir(&self) -> &Path {
41        self.inner.path()
42    }
43
44    /// Get the current repository state.
45    #[must_use]
46    pub fn state(&self) -> RepositoryState {
47        self.inner.state()
48    }
49
50    /// Check if there's a rebase in progress.
51    #[must_use]
52    pub fn is_rebasing(&self) -> bool {
53        matches!(
54            self.state(),
55            RepositoryState::Rebase
56                | RepositoryState::RebaseInteractive
57                | RepositoryState::RebaseMerge
58        )
59    }
60
61    // === Branch operations ===
62
63    /// Get the name of the current branch.
64    ///
65    /// # Errors
66    /// Returns error if HEAD is detached.
67    pub fn current_branch(&self) -> Result<String> {
68        let head = self.inner.head()?;
69        if !head.is_branch() {
70            return Err(Error::DetachedHead);
71        }
72
73        head.shorthand()
74            .map(String::from)
75            .ok_or(Error::DetachedHead)
76    }
77
78    /// Get the commit SHA for a branch.
79    ///
80    /// # Errors
81    /// Returns error if branch doesn't exist.
82    pub fn branch_commit(&self, branch_name: &str) -> Result<Oid> {
83        let branch = self
84            .inner
85            .find_branch(branch_name, BranchType::Local)
86            .map_err(|_| Error::BranchNotFound(branch_name.into()))?;
87
88        branch
89            .get()
90            .target()
91            .ok_or_else(|| Error::BranchNotFound(branch_name.into()))
92    }
93
94    /// Get the commit ID of a remote branch tip.
95    ///
96    /// # Errors
97    /// Returns error if branch not found.
98    pub fn remote_branch_commit(&self, branch_name: &str) -> Result<Oid> {
99        let ref_name = format!("refs/remotes/origin/{branch_name}");
100        let reference = self
101            .inner
102            .find_reference(&ref_name)
103            .map_err(|_| Error::BranchNotFound(format!("origin/{branch_name}")))?;
104
105        reference
106            .target()
107            .ok_or_else(|| Error::BranchNotFound(format!("origin/{branch_name}")))
108    }
109
110    /// Create a new branch at the current HEAD.
111    ///
112    /// # Errors
113    /// Returns error if branch creation fails.
114    pub fn create_branch(&self, name: &str) -> Result<Oid> {
115        let head_commit = self.inner.head()?.peel_to_commit()?;
116        let branch = self.inner.branch(name, &head_commit, false)?;
117
118        branch
119            .get()
120            .target()
121            .ok_or_else(|| Error::BranchNotFound(name.into()))
122    }
123
124    /// Checkout a branch.
125    ///
126    /// # Errors
127    /// Returns error if checkout fails.
128    pub fn checkout(&self, branch_name: &str) -> Result<()> {
129        let branch = self
130            .inner
131            .find_branch(branch_name, BranchType::Local)
132            .map_err(|_| Error::BranchNotFound(branch_name.into()))?;
133
134        let reference = branch.get();
135        let object = reference.peel(git2::ObjectType::Commit)?;
136
137        self.inner.checkout_tree(&object, None)?;
138        self.inner.set_head(&format!("refs/heads/{branch_name}"))?;
139
140        Ok(())
141    }
142
143    /// List all local branches.
144    ///
145    /// # Errors
146    /// Returns error if branch listing fails.
147    pub fn list_branches(&self) -> Result<Vec<String>> {
148        let branches = self.inner.branches(Some(BranchType::Local))?;
149
150        let names: Vec<String> = branches
151            .filter_map(std::result::Result::ok)
152            .filter_map(|(b, _)| b.name().ok().flatten().map(String::from))
153            .collect();
154
155        Ok(names)
156    }
157
158    /// Check if a branch exists.
159    #[must_use]
160    pub fn branch_exists(&self, name: &str) -> bool {
161        self.inner.find_branch(name, BranchType::Local).is_ok()
162    }
163
164    /// Delete a local branch.
165    ///
166    /// # Errors
167    /// Returns error if branch deletion fails.
168    pub fn delete_branch(&self, name: &str) -> Result<()> {
169        let mut branch = self.inner.find_branch(name, BranchType::Local)?;
170        branch.delete()?;
171        Ok(())
172    }
173
174    // === Working directory state ===
175
176    /// Check if the working directory is clean (no modified or staged files).
177    ///
178    /// Untracked files are ignored - only tracked files that have been
179    /// modified or staged count as "dirty".
180    ///
181    /// # Errors
182    /// Returns error if status check fails.
183    pub fn is_clean(&self) -> Result<bool> {
184        let mut opts = git2::StatusOptions::new();
185        opts.include_untracked(false)
186            .include_ignored(false)
187            .include_unmodified(false)
188            .exclude_submodules(true);
189        let statuses = self.inner.statuses(Some(&mut opts))?;
190
191        // Check if any status indicates modified/staged files
192        for entry in statuses.iter() {
193            let status = entry.status();
194            // These indicate actual changes to tracked files
195            if status.intersects(
196                git2::Status::INDEX_NEW
197                    | git2::Status::INDEX_MODIFIED
198                    | git2::Status::INDEX_DELETED
199                    | git2::Status::INDEX_RENAMED
200                    | git2::Status::INDEX_TYPECHANGE
201                    | git2::Status::WT_MODIFIED
202                    | git2::Status::WT_DELETED
203                    | git2::Status::WT_TYPECHANGE
204                    | git2::Status::WT_RENAMED,
205            ) {
206                return Ok(false);
207            }
208        }
209        Ok(true)
210    }
211
212    /// Ensure working directory is clean, returning error if not.
213    ///
214    /// # Errors
215    /// Returns `DirtyWorkingDirectory` if there are uncommitted changes.
216    pub fn require_clean(&self) -> Result<()> {
217        if self.is_clean()? {
218            Ok(())
219        } else {
220            Err(Error::DirtyWorkingDirectory)
221        }
222    }
223
224    // === Staging operations ===
225
226    /// Stage all changes (tracked and untracked files).
227    ///
228    /// Equivalent to `git add -A`.
229    ///
230    /// # Errors
231    /// Returns error if staging fails.
232    pub fn stage_all(&self) -> Result<()> {
233        let workdir = self.workdir().ok_or(Error::NotARepository)?;
234
235        let output = std::process::Command::new("git")
236            .args(["add", "-A"])
237            .current_dir(workdir)
238            .output()
239            .map_err(|e| Error::Git2(git2::Error::from_str(&e.to_string())))?;
240
241        if output.status.success() {
242            Ok(())
243        } else {
244            let stderr = String::from_utf8_lossy(&output.stderr);
245            Err(Error::Git2(git2::Error::from_str(&stderr)))
246        }
247    }
248
249    /// Check if there are staged changes ready to commit.
250    ///
251    /// # Errors
252    /// Returns error if status check fails.
253    pub fn has_staged_changes(&self) -> Result<bool> {
254        let mut opts = git2::StatusOptions::new();
255        opts.include_untracked(false)
256            .include_ignored(false)
257            .include_unmodified(false);
258        let statuses = self.inner.statuses(Some(&mut opts))?;
259
260        for entry in statuses.iter() {
261            let status = entry.status();
262            if status.intersects(
263                git2::Status::INDEX_NEW
264                    | git2::Status::INDEX_MODIFIED
265                    | git2::Status::INDEX_DELETED
266                    | git2::Status::INDEX_RENAMED
267                    | git2::Status::INDEX_TYPECHANGE,
268            ) {
269                return Ok(true);
270            }
271        }
272        Ok(false)
273    }
274
275    /// Create a commit with the given message on HEAD.
276    ///
277    /// Handles both normal commits (with parent) and initial commits (no parent).
278    ///
279    /// # Errors
280    /// Returns error if commit creation fails.
281    pub fn create_commit(&self, message: &str) -> Result<Oid> {
282        let sig = self.signature()?;
283        let mut index = self.inner.index()?;
284        let tree_id = index.write_tree()?;
285        let tree = self.inner.find_tree(tree_id)?;
286
287        // Handle initial commit case (unborn HEAD)
288        let oid = match self.inner.head().and_then(|h| h.peel_to_commit()) {
289            Ok(parent) => {
290                self.inner
291                    .commit(Some("HEAD"), &sig, &sig, message, &tree, &[&parent])?
292            }
293            Err(_) => {
294                // Initial commit - no parent
295                self.inner
296                    .commit(Some("HEAD"), &sig, &sig, message, &tree, &[])?
297            }
298        };
299
300        Ok(oid)
301    }
302
303    // === Commit operations ===
304
305    /// Get a commit by its SHA.
306    ///
307    /// # Errors
308    /// Returns error if commit not found.
309    pub fn find_commit(&self, oid: Oid) -> Result<git2::Commit<'_>> {
310        Ok(self.inner.find_commit(oid)?)
311    }
312
313    /// Get the commit message from a branch's tip commit.
314    ///
315    /// # Errors
316    /// Returns error if branch doesn't exist or has no commits.
317    pub fn branch_commit_message(&self, branch_name: &str) -> Result<String> {
318        let oid = self.branch_commit(branch_name)?;
319        let commit = self.inner.find_commit(oid)?;
320        commit
321            .message()
322            .map(String::from)
323            .ok_or_else(|| Error::Git2(git2::Error::from_str("commit has no message")))
324    }
325
326    /// Get the merge base between two commits.
327    ///
328    /// # Errors
329    /// Returns error if merge base calculation fails.
330    pub fn merge_base(&self, one: Oid, two: Oid) -> Result<Oid> {
331        Ok(self.inner.merge_base(one, two)?)
332    }
333
334    /// Count commits between two points.
335    ///
336    /// # Errors
337    /// Returns error if revwalk fails.
338    pub fn count_commits_between(&self, from: Oid, to: Oid) -> Result<usize> {
339        let mut revwalk = self.inner.revwalk()?;
340        revwalk.push(to)?;
341        revwalk.hide(from)?;
342
343        Ok(revwalk.count())
344    }
345
346    /// Get commits between two points.
347    ///
348    /// # Errors
349    /// Return error if revwalk fails.
350    pub fn commits_between(&self, from: Oid, to: Oid) -> Result<Vec<Oid>> {
351        let mut revwalk = self.inner.revwalk()?;
352        revwalk.push(to)?;
353        revwalk.hide(from)?;
354
355        let mut commits = Vec::new();
356        for oid in revwalk {
357            let oid = oid?;
358            commits.push(oid);
359        }
360
361        Ok(commits)
362    }
363
364    // === Reset operations ===
365
366    /// Hard reset a branch to a specific commit.
367    ///
368    /// # Errors
369    /// Returns error if reset fails.
370    pub fn reset_branch(&self, branch_name: &str, target: Oid) -> Result<()> {
371        let commit = self.inner.find_commit(target)?;
372        let reference_name = format!("refs/heads/{branch_name}");
373
374        self.inner.reference(
375            &reference_name,
376            target,
377            true, // force
378            &format!("rung: reset to {}", &target.to_string()[..8]),
379        )?;
380
381        // If this is the current branch, also update working directory
382        if self.current_branch().ok().as_deref() == Some(branch_name) {
383            self.inner
384                .reset(commit.as_object(), git2::ResetType::Hard, None)?;
385        }
386
387        Ok(())
388    }
389
390    // === Signature ===
391
392    /// Get the default signature for commits.
393    ///
394    /// # Errors
395    /// Returns error if git config doesn't have user.name/email.
396    pub fn signature(&self) -> Result<Signature<'_>> {
397        Ok(self.inner.signature()?)
398    }
399
400    // === Rebase operations ===
401
402    /// Rebase the current branch onto a target commit.
403    ///
404    /// Returns `Ok(())` on success, or `Err(RebaseConflict)` if there are conflicts.
405    ///
406    /// # Errors
407    /// Returns error if rebase fails or conflicts occur.
408    pub fn rebase_onto(&self, target: Oid) -> Result<()> {
409        let workdir = self.workdir().ok_or(Error::NotARepository)?;
410
411        let output = std::process::Command::new("git")
412            .args(["rebase", &target.to_string()])
413            .current_dir(workdir)
414            .output()
415            .map_err(|e| Error::RebaseFailed(e.to_string()))?;
416
417        if output.status.success() {
418            return Ok(());
419        }
420
421        // Check if it's a conflict
422        if self.is_rebasing() {
423            let conflicts = self.conflicting_files()?;
424            return Err(Error::RebaseConflict(conflicts));
425        }
426
427        let stderr = String::from_utf8_lossy(&output.stderr);
428        Err(Error::RebaseFailed(stderr.to_string()))
429    }
430
431    /// Rebase the current branch onto a new base, replaying only commits after `old_base`.
432    ///
433    /// This is equivalent to `git rebase --onto <new_base> <old_base>`.
434    /// Use this when the `old_base` was squash-merged and you want to bring only
435    /// the unique commits from the current branch.
436    ///
437    /// # Errors
438    /// Returns error if rebase fails or conflicts occur.
439    pub fn rebase_onto_from(&self, new_base: Oid, old_base: Oid) -> Result<()> {
440        let workdir = self.workdir().ok_or(Error::NotARepository)?;
441
442        let output = std::process::Command::new("git")
443            .args([
444                "rebase",
445                "--onto",
446                &new_base.to_string(),
447                &old_base.to_string(),
448            ])
449            .current_dir(workdir)
450            .output()
451            .map_err(|e| Error::RebaseFailed(e.to_string()))?;
452
453        if output.status.success() {
454            return Ok(());
455        }
456
457        // Check if it's a conflict
458        if self.is_rebasing() {
459            let conflicts = self.conflicting_files()?;
460            return Err(Error::RebaseConflict(conflicts));
461        }
462
463        let stderr = String::from_utf8_lossy(&output.stderr);
464        Err(Error::RebaseFailed(stderr.to_string()))
465    }
466
467    /// Get list of files with conflicts.
468    ///
469    /// # Errors
470    /// Returns error if status check fails.
471    pub fn conflicting_files(&self) -> Result<Vec<String>> {
472        let statuses = self.inner.statuses(None)?;
473        let conflicts: Vec<String> = statuses
474            .iter()
475            .filter(|s| s.status().is_conflicted())
476            .filter_map(|s| s.path().map(String::from))
477            .collect();
478        Ok(conflicts)
479    }
480
481    /// Abort an in-progress rebase.
482    ///
483    /// # Errors
484    /// Returns error if abort fails.
485    pub fn rebase_abort(&self) -> Result<()> {
486        let workdir = self.workdir().ok_or(Error::NotARepository)?;
487
488        let output = std::process::Command::new("git")
489            .args(["rebase", "--abort"])
490            .current_dir(workdir)
491            .output()
492            .map_err(|e| Error::RebaseFailed(e.to_string()))?;
493
494        if output.status.success() {
495            Ok(())
496        } else {
497            let stderr = String::from_utf8_lossy(&output.stderr);
498            Err(Error::RebaseFailed(stderr.to_string()))
499        }
500    }
501
502    /// Continue an in-progress rebase.
503    ///
504    /// # Errors
505    /// Returns error if continue fails or new conflicts occur.
506    pub fn rebase_continue(&self) -> Result<()> {
507        let workdir = self.workdir().ok_or(Error::NotARepository)?;
508
509        let output = std::process::Command::new("git")
510            .args(["rebase", "--continue"])
511            .current_dir(workdir)
512            .output()
513            .map_err(|e| Error::RebaseFailed(e.to_string()))?;
514
515        if output.status.success() {
516            return Ok(());
517        }
518
519        // Check if it's a conflict
520        if self.is_rebasing() {
521            let conflicts = self.conflicting_files()?;
522            return Err(Error::RebaseConflict(conflicts));
523        }
524
525        let stderr = String::from_utf8_lossy(&output.stderr);
526        Err(Error::RebaseFailed(stderr.to_string()))
527    }
528
529    // === Remote operations ===
530
531    /// Get the URL of the origin remote.
532    ///
533    /// # Errors
534    /// Returns error if origin remote is not found.
535    pub fn origin_url(&self) -> Result<String> {
536        let remote = self
537            .inner
538            .find_remote("origin")
539            .map_err(|_| Error::RemoteNotFound("origin".into()))?;
540
541        remote
542            .url()
543            .map(String::from)
544            .ok_or_else(|| Error::RemoteNotFound("origin".into()))
545    }
546
547    /// Parse owner and repo name from a GitHub URL.
548    ///
549    /// Supports both HTTPS and SSH URLs:
550    /// - `https://github.com/owner/repo.git`
551    /// - `git@github.com:owner/repo.git`
552    ///
553    /// # Errors
554    /// Returns error if URL cannot be parsed.
555    pub fn parse_github_remote(url: &str) -> Result<(String, String)> {
556        // SSH format: git@github.com:owner/repo.git
557        if let Some(rest) = url.strip_prefix("git@github.com:") {
558            let path = rest.strip_suffix(".git").unwrap_or(rest);
559            if let Some((owner, repo)) = path.split_once('/') {
560                return Ok((owner.to_string(), repo.to_string()));
561            }
562        }
563
564        // HTTPS format: https://github.com/owner/repo.git
565        if let Some(rest) = url
566            .strip_prefix("https://github.com/")
567            .or_else(|| url.strip_prefix("http://github.com/"))
568        {
569            let path = rest.strip_suffix(".git").unwrap_or(rest);
570            if let Some((owner, repo)) = path.split_once('/') {
571                return Ok((owner.to_string(), repo.to_string()));
572            }
573        }
574
575        Err(Error::InvalidRemoteUrl(url.to_string()))
576    }
577
578    /// Push a branch to the remote.
579    ///
580    /// # Errors
581    /// Returns error if push fails.
582    pub fn push(&self, branch: &str, force: bool) -> Result<()> {
583        let workdir = self.workdir().ok_or(Error::NotARepository)?;
584
585        let mut args = vec!["push", "-u", "origin", branch];
586        if force {
587            args.insert(1, "--force-with-lease");
588        }
589
590        let output = std::process::Command::new("git")
591            .args(&args)
592            .current_dir(workdir)
593            .output()
594            .map_err(|e| Error::PushFailed(e.to_string()))?;
595
596        if output.status.success() {
597            Ok(())
598        } else {
599            let stderr = String::from_utf8_lossy(&output.stderr);
600            Err(Error::PushFailed(stderr.to_string()))
601        }
602    }
603
604    /// Fetch a branch from origin.
605    ///
606    /// # Errors
607    /// Returns error if fetch fails.
608    pub fn fetch(&self, branch: &str) -> Result<()> {
609        let workdir = self.workdir().ok_or(Error::NotARepository)?;
610
611        // Use refspec to update both remote tracking branch and local branch
612        // Format: origin/branch:refs/heads/branch
613        let refspec = format!("{branch}:refs/heads/{branch}");
614        let output = std::process::Command::new("git")
615            .args(["fetch", "origin", &refspec])
616            .current_dir(workdir)
617            .output()
618            .map_err(|e| Error::FetchFailed(e.to_string()))?;
619
620        if output.status.success() {
621            Ok(())
622        } else {
623            let stderr = String::from_utf8_lossy(&output.stderr);
624            Err(Error::FetchFailed(stderr.to_string()))
625        }
626    }
627
628    /// Pull (fast-forward only) the current branch from origin.
629    ///
630    /// This fetches and merges `origin/<branch>` into the current branch,
631    /// but only if it can be fast-forwarded.
632    ///
633    /// # Errors
634    /// Returns error if pull fails or fast-forward is not possible.
635    pub fn pull_ff(&self) -> Result<()> {
636        let workdir = self.workdir().ok_or(Error::NotARepository)?;
637
638        let output = std::process::Command::new("git")
639            .args(["pull", "--ff-only"])
640            .current_dir(workdir)
641            .output()
642            .map_err(|e| Error::FetchFailed(e.to_string()))?;
643
644        if output.status.success() {
645            Ok(())
646        } else {
647            let stderr = String::from_utf8_lossy(&output.stderr);
648            Err(Error::FetchFailed(stderr.to_string()))
649        }
650    }
651
652    // === Low-level access ===
653
654    /// Get a reference to the underlying git2 repository.
655    ///
656    /// Use sparingly - prefer high-level methods.
657    #[must_use]
658    pub const fn inner(&self) -> &git2::Repository {
659        &self.inner
660    }
661}
662
663impl std::fmt::Debug for Repository {
664    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
665        f.debug_struct("Repository")
666            .field("path", &self.git_dir())
667            .finish()
668    }
669}
670
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Create a temporary repository containing one empty root commit.
    ///
    /// The `TempDir` is returned too so the directory outlives the test body.
    fn init_test_repo() -> (TempDir, Repository) {
        let temp = TempDir::new().unwrap();
        let repo = git2::Repository::init(temp.path()).unwrap();

        // Create initial commit with owned signature (avoids borrowing repo)
        let sig = git2::Signature::now("Test", "test@example.com").unwrap();
        let tree_id = repo.index().unwrap().write_tree().unwrap();
        let tree = repo.find_tree(tree_id).unwrap();
        repo.commit(Some("HEAD"), &sig, &sig, "Initial commit", &tree, &[])
            .unwrap();
        // `tree` borrows `repo`; release it before moving `repo` into the wrapper.
        drop(tree);

        let wrapped = Repository { inner: repo };
        (temp, wrapped)
    }

    #[test]
    fn test_current_branch() {
        let (_temp, repo) = init_test_repo();
        // Default branch after init
        let branch = repo.current_branch().unwrap();
        assert!(branch == "main" || branch == "master");
    }

    #[test]
    fn test_create_and_checkout_branch() {
        let (_temp, repo) = init_test_repo();

        repo.create_branch("feature/test").unwrap();
        assert!(repo.branch_exists("feature/test"));

        repo.checkout("feature/test").unwrap();
        assert_eq!(repo.current_branch().unwrap(), "feature/test");
    }

    #[test]
    fn test_is_clean() {
        let (temp, repo) = init_test_repo();

        assert!(repo.is_clean().unwrap());

        // Create and commit a tracked file
        fs::write(temp.path().join("test.txt"), "initial").unwrap();
        {
            let mut index = repo.inner.index().unwrap();
            index.add_path(std::path::Path::new("test.txt")).unwrap();
            index.write().unwrap();
            let tree_id = index.write_tree().unwrap();
            let tree = repo.inner.find_tree(tree_id).unwrap();
            let parent = repo.inner.head().unwrap().peel_to_commit().unwrap();
            let sig = git2::Signature::now("Test", "test@example.com").unwrap();
            repo.inner
                .commit(Some("HEAD"), &sig, &sig, "Add test file", &tree, &[&parent])
                .unwrap();
        }

        // Should still be clean after commit
        assert!(repo.is_clean().unwrap());

        // Modify tracked file
        fs::write(temp.path().join("test.txt"), "modified").unwrap();
        assert!(!repo.is_clean().unwrap());
    }

    #[test]
    fn test_list_branches() {
        let (_temp, repo) = init_test_repo();

        repo.create_branch("feature/a").unwrap();
        repo.create_branch("feature/b").unwrap();

        let branches = repo.list_branches().unwrap();
        assert!(branches.len() >= 3); // main/master + 2 features
        assert!(branches.iter().any(|b| b == "feature/a"));
        assert!(branches.iter().any(|b| b == "feature/b"));
    }

    #[test]
    fn test_branch_commit() {
        let (_temp, repo) = init_test_repo();

        let head = repo.inner.head().unwrap().target().unwrap();
        let current = repo.current_branch().unwrap();
        assert_eq!(repo.branch_commit(&current).unwrap(), head);
        assert!(repo.branch_commit("no-such-branch").is_err());
    }

    #[test]
    fn test_delete_branch() {
        let (_temp, repo) = init_test_repo();

        repo.create_branch("feature/doomed").unwrap();
        assert!(repo.branch_exists("feature/doomed"));

        repo.delete_branch("feature/doomed").unwrap();
        assert!(!repo.branch_exists("feature/doomed"));
    }

    #[test]
    fn test_count_commits_between_same_commit() {
        let (_temp, repo) = init_test_repo();

        let head = repo.inner.head().unwrap().target().unwrap();
        assert_eq!(repo.count_commits_between(head, head).unwrap(), 0);
        assert!(repo.commits_between(head, head).unwrap().is_empty());
    }

    #[test]
    fn test_not_rebasing_after_init() {
        let (_temp, repo) = init_test_repo();
        assert!(!repo.is_rebasing());
    }

    #[test]
    fn test_parse_github_remote() {
        let expected = ("owner".to_string(), "repo".to_string());
        for url in [
            "git@github.com:owner/repo.git",
            "git@github.com:owner/repo",
            "https://github.com/owner/repo.git",
            "https://github.com/owner/repo",
            "http://github.com/owner/repo.git",
        ] {
            assert_eq!(
                Repository::parse_github_remote(url).unwrap(),
                expected,
                "{url}"
            );
        }

        for url in [
            "https://gitlab.com/owner/repo.git",
            "git@github.com:repo-only",
            "not a url",
        ] {
            assert!(Repository::parse_github_remote(url).is_err(), "{url}");
        }
    }
}