rung_git/
repository.rs

1//! Repository wrapper providing high-level git operations.
2
3use std::path::Path;
4
5use git2::{BranchType, Oid, RepositoryState, Signature};
6
7use crate::error::{Error, Result};
8
9/// High-level wrapper around a git repository.
10pub struct Repository {
11    inner: git2::Repository,
12}
13
14impl Repository {
15    /// Open a repository at the given path.
16    ///
17    /// # Errors
18    /// Returns error if no repository found at path or any parent.
19    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
20        let inner = git2::Repository::discover(path)?;
21        Ok(Self { inner })
22    }
23
24    /// Open the repository containing the current directory.
25    ///
26    /// # Errors
27    /// Returns error if not inside a git repository.
28    pub fn open_current() -> Result<Self> {
29        Self::open(".")
30    }
31
32    /// Get the path to the repository root (workdir).
33    #[must_use]
34    pub fn workdir(&self) -> Option<&Path> {
35        self.inner.workdir()
36    }
37
38    /// Get the path to the .git directory.
39    #[must_use]
40    pub fn git_dir(&self) -> &Path {
41        self.inner.path()
42    }
43
44    /// Get the current repository state.
45    #[must_use]
46    pub fn state(&self) -> RepositoryState {
47        self.inner.state()
48    }
49
50    /// Check if there's a rebase in progress.
51    #[must_use]
52    pub fn is_rebasing(&self) -> bool {
53        matches!(
54            self.state(),
55            RepositoryState::Rebase
56                | RepositoryState::RebaseInteractive
57                | RepositoryState::RebaseMerge
58        )
59    }
60
61    /// Check if HEAD is detached (not pointing at a branch).
62    ///
63    /// # Errors
64    /// Returns error if HEAD cannot be read (e.g. unborn repo).
65    pub fn head_detached(&self) -> Result<bool> {
66        let head = self.inner.head()?;
67        Ok(!head.is_branch())
68    }
69
70    // === Branch operations ===
71
72    /// Get the name of the current branch.
73    ///
74    /// # Errors
75    /// Returns error if HEAD is detached.
76    pub fn current_branch(&self) -> Result<String> {
77        let head = self.inner.head()?;
78        if !head.is_branch() {
79            return Err(Error::DetachedHead);
80        }
81
82        head.shorthand()
83            .map(String::from)
84            .ok_or(Error::DetachedHead)
85    }
86
87    /// Get the commit SHA for a branch.
88    ///
89    /// # Errors
90    /// Returns error if branch doesn't exist.
91    pub fn branch_commit(&self, branch_name: &str) -> Result<Oid> {
92        let branch = self
93            .inner
94            .find_branch(branch_name, BranchType::Local)
95            .map_err(|_| Error::BranchNotFound(branch_name.into()))?;
96
97        branch
98            .get()
99            .target()
100            .ok_or_else(|| Error::BranchNotFound(branch_name.into()))
101    }
102
103    /// Get the commit ID of a remote branch tip.
104    ///
105    /// # Errors
106    /// Returns error if branch not found.
107    pub fn remote_branch_commit(&self, branch_name: &str) -> Result<Oid> {
108        let ref_name = format!("refs/remotes/origin/{branch_name}");
109        let reference = self
110            .inner
111            .find_reference(&ref_name)
112            .map_err(|_| Error::BranchNotFound(format!("origin/{branch_name}")))?;
113
114        reference
115            .target()
116            .ok_or_else(|| Error::BranchNotFound(format!("origin/{branch_name}")))
117    }
118
119    /// Create a new branch at the current HEAD.
120    ///
121    /// # Errors
122    /// Returns error if branch creation fails.
123    pub fn create_branch(&self, name: &str) -> Result<Oid> {
124        let head_commit = self.inner.head()?.peel_to_commit()?;
125        let branch = self.inner.branch(name, &head_commit, false)?;
126
127        branch
128            .get()
129            .target()
130            .ok_or_else(|| Error::BranchNotFound(name.into()))
131    }
132
133    /// Checkout a branch.
134    ///
135    /// # Errors
136    /// Returns error if checkout fails.
137    pub fn checkout(&self, branch_name: &str) -> Result<()> {
138        let branch = self
139            .inner
140            .find_branch(branch_name, BranchType::Local)
141            .map_err(|_| Error::BranchNotFound(branch_name.into()))?;
142
143        let reference = branch.get();
144        let object = reference.peel(git2::ObjectType::Commit)?;
145
146        self.inner.checkout_tree(&object, None)?;
147        self.inner.set_head(&format!("refs/heads/{branch_name}"))?;
148
149        Ok(())
150    }
151
152    /// List all local branches.
153    ///
154    /// # Errors
155    /// Returns error if branch listing fails.
156    pub fn list_branches(&self) -> Result<Vec<String>> {
157        let branches = self.inner.branches(Some(BranchType::Local))?;
158
159        let names: Vec<String> = branches
160            .filter_map(std::result::Result::ok)
161            .filter_map(|(b, _)| b.name().ok().flatten().map(String::from))
162            .collect();
163
164        Ok(names)
165    }
166
167    /// Check if a branch exists.
168    #[must_use]
169    pub fn branch_exists(&self, name: &str) -> bool {
170        self.inner.find_branch(name, BranchType::Local).is_ok()
171    }
172
173    /// Delete a local branch.
174    ///
175    /// # Errors
176    /// Returns error if branch deletion fails.
177    pub fn delete_branch(&self, name: &str) -> Result<()> {
178        let mut branch = self.inner.find_branch(name, BranchType::Local)?;
179        branch.delete()?;
180        Ok(())
181    }
182
183    // === Working directory state ===
184
185    /// Check if the working directory is clean (no modified or staged files).
186    ///
187    /// Untracked files are ignored - only tracked files that have been
188    /// modified or staged count as "dirty".
189    ///
190    /// # Errors
191    /// Returns error if status check fails.
192    pub fn is_clean(&self) -> Result<bool> {
193        let mut opts = git2::StatusOptions::new();
194        opts.include_untracked(false)
195            .include_ignored(false)
196            .include_unmodified(false)
197            .exclude_submodules(true);
198        let statuses = self.inner.statuses(Some(&mut opts))?;
199
200        // Check if any status indicates modified/staged files
201        for entry in statuses.iter() {
202            let status = entry.status();
203            // These indicate actual changes to tracked files
204            if status.intersects(
205                git2::Status::INDEX_NEW
206                    | git2::Status::INDEX_MODIFIED
207                    | git2::Status::INDEX_DELETED
208                    | git2::Status::INDEX_RENAMED
209                    | git2::Status::INDEX_TYPECHANGE
210                    | git2::Status::WT_MODIFIED
211                    | git2::Status::WT_DELETED
212                    | git2::Status::WT_TYPECHANGE
213                    | git2::Status::WT_RENAMED,
214            ) {
215                return Ok(false);
216            }
217        }
218        Ok(true)
219    }
220
221    /// Ensure working directory is clean, returning error if not.
222    ///
223    /// # Errors
224    /// Returns `DirtyWorkingDirectory` if there are uncommitted changes.
225    pub fn require_clean(&self) -> Result<()> {
226        if self.is_clean()? {
227            Ok(())
228        } else {
229            Err(Error::DirtyWorkingDirectory)
230        }
231    }
232
233    // === Staging operations ===
234
235    /// Stage all changes (tracked and untracked files).
236    ///
237    /// Equivalent to `git add -A`.
238    ///
239    /// # Errors
240    /// Returns error if staging fails.
241    pub fn stage_all(&self) -> Result<()> {
242        let workdir = self.workdir().ok_or(Error::NotARepository)?;
243
244        let output = std::process::Command::new("git")
245            .args(["add", "-A"])
246            .current_dir(workdir)
247            .output()
248            .map_err(|e| Error::Git2(git2::Error::from_str(&e.to_string())))?;
249
250        if output.status.success() {
251            Ok(())
252        } else {
253            let stderr = String::from_utf8_lossy(&output.stderr);
254            Err(Error::Git2(git2::Error::from_str(&stderr)))
255        }
256    }
257
258    /// Check if there are staged changes ready to commit.
259    ///
260    /// # Errors
261    /// Returns error if status check fails.
262    pub fn has_staged_changes(&self) -> Result<bool> {
263        let mut opts = git2::StatusOptions::new();
264        opts.include_untracked(false)
265            .include_ignored(false)
266            .include_unmodified(false);
267        let statuses = self.inner.statuses(Some(&mut opts))?;
268
269        for entry in statuses.iter() {
270            let status = entry.status();
271            if status.intersects(
272                git2::Status::INDEX_NEW
273                    | git2::Status::INDEX_MODIFIED
274                    | git2::Status::INDEX_DELETED
275                    | git2::Status::INDEX_RENAMED
276                    | git2::Status::INDEX_TYPECHANGE,
277            ) {
278                return Ok(true);
279            }
280        }
281        Ok(false)
282    }
283
284    /// Create a commit with the given message on HEAD.
285    ///
286    /// Handles both normal commits (with parent) and initial commits (no parent).
287    ///
288    /// # Errors
289    /// Returns error if commit creation fails.
290    pub fn create_commit(&self, message: &str) -> Result<Oid> {
291        let sig = self.signature()?;
292        let mut index = self.inner.index()?;
293        let tree_id = index.write_tree()?;
294        let tree = self.inner.find_tree(tree_id)?;
295
296        // Handle initial commit case (unborn HEAD)
297        let oid = match self.inner.head().and_then(|h| h.peel_to_commit()) {
298            Ok(parent) => {
299                self.inner
300                    .commit(Some("HEAD"), &sig, &sig, message, &tree, &[&parent])?
301            }
302            Err(_) => {
303                // Initial commit - no parent
304                self.inner
305                    .commit(Some("HEAD"), &sig, &sig, message, &tree, &[])?
306            }
307        };
308
309        Ok(oid)
310    }
311
312    // === Commit operations ===
313
314    /// Get a commit by its SHA.
315    ///
316    /// # Errors
317    /// Returns error if commit not found.
318    pub fn find_commit(&self, oid: Oid) -> Result<git2::Commit<'_>> {
319        Ok(self.inner.find_commit(oid)?)
320    }
321
322    /// Get the commit message from a branch's tip commit.
323    ///
324    /// # Errors
325    /// Returns error if branch doesn't exist or has no commits.
326    pub fn branch_commit_message(&self, branch_name: &str) -> Result<String> {
327        let oid = self.branch_commit(branch_name)?;
328        let commit = self.inner.find_commit(oid)?;
329        commit
330            .message()
331            .map(String::from)
332            .ok_or_else(|| Error::Git2(git2::Error::from_str("commit has no message")))
333    }
334
335    /// Get the merge base between two commits.
336    ///
337    /// # Errors
338    /// Returns error if merge base calculation fails.
339    pub fn merge_base(&self, one: Oid, two: Oid) -> Result<Oid> {
340        Ok(self.inner.merge_base(one, two)?)
341    }
342
343    /// Count commits between two points.
344    ///
345    /// # Errors
346    /// Returns error if revwalk fails.
347    pub fn count_commits_between(&self, from: Oid, to: Oid) -> Result<usize> {
348        let mut revwalk = self.inner.revwalk()?;
349        revwalk.push(to)?;
350        revwalk.hide(from)?;
351
352        Ok(revwalk.count())
353    }
354
355    /// Get commits between two points.
356    ///
357    /// # Errors
358    /// Return error if revwalk fails.
359    pub fn commits_between(&self, from: Oid, to: Oid) -> Result<Vec<Oid>> {
360        let mut revwalk = self.inner.revwalk()?;
361        revwalk.push(to)?;
362        revwalk.hide(from)?;
363
364        let mut commits = Vec::new();
365        for oid in revwalk {
366            let oid = oid?;
367            commits.push(oid);
368        }
369
370        Ok(commits)
371    }
372
373    // === Reset operations ===
374
375    /// Hard reset a branch to a specific commit.
376    ///
377    /// # Errors
378    /// Returns error if reset fails.
379    pub fn reset_branch(&self, branch_name: &str, target: Oid) -> Result<()> {
380        let commit = self.inner.find_commit(target)?;
381        let reference_name = format!("refs/heads/{branch_name}");
382
383        self.inner.reference(
384            &reference_name,
385            target,
386            true, // force
387            &format!("rung: reset to {}", &target.to_string()[..8]),
388        )?;
389
390        // If this is the current branch, also update working directory
391        if self.current_branch().ok().as_deref() == Some(branch_name) {
392            self.inner
393                .reset(commit.as_object(), git2::ResetType::Hard, None)?;
394        }
395
396        Ok(())
397    }
398
399    // === Signature ===
400
401    /// Get the default signature for commits.
402    ///
403    /// # Errors
404    /// Returns error if git config doesn't have user.name/email.
405    pub fn signature(&self) -> Result<Signature<'_>> {
406        Ok(self.inner.signature()?)
407    }
408
409    // === Rebase operations ===
410
411    /// Rebase the current branch onto a target commit.
412    ///
413    /// Returns `Ok(())` on success, or `Err(RebaseConflict)` if there are conflicts.
414    ///
415    /// # Errors
416    /// Returns error if rebase fails or conflicts occur.
417    pub fn rebase_onto(&self, target: Oid) -> Result<()> {
418        let workdir = self.workdir().ok_or(Error::NotARepository)?;
419
420        let output = std::process::Command::new("git")
421            .args(["rebase", &target.to_string()])
422            .current_dir(workdir)
423            .output()
424            .map_err(|e| Error::RebaseFailed(e.to_string()))?;
425
426        if output.status.success() {
427            return Ok(());
428        }
429
430        // Check if it's a conflict
431        if self.is_rebasing() {
432            let conflicts = self.conflicting_files()?;
433            return Err(Error::RebaseConflict(conflicts));
434        }
435
436        let stderr = String::from_utf8_lossy(&output.stderr);
437        Err(Error::RebaseFailed(stderr.to_string()))
438    }
439
440    /// Rebase the current branch onto a new base, replaying only commits after `old_base`.
441    ///
442    /// This is equivalent to `git rebase --onto <new_base> <old_base>`.
443    /// Use this when the `old_base` was squash-merged and you want to bring only
444    /// the unique commits from the current branch.
445    ///
446    /// # Errors
447    /// Returns error if rebase fails or conflicts occur.
448    pub fn rebase_onto_from(&self, new_base: Oid, old_base: Oid) -> Result<()> {
449        let workdir = self.workdir().ok_or(Error::NotARepository)?;
450
451        let output = std::process::Command::new("git")
452            .args([
453                "rebase",
454                "--onto",
455                &new_base.to_string(),
456                &old_base.to_string(),
457            ])
458            .current_dir(workdir)
459            .output()
460            .map_err(|e| Error::RebaseFailed(e.to_string()))?;
461
462        if output.status.success() {
463            return Ok(());
464        }
465
466        // Check if it's a conflict
467        if self.is_rebasing() {
468            let conflicts = self.conflicting_files()?;
469            return Err(Error::RebaseConflict(conflicts));
470        }
471
472        let stderr = String::from_utf8_lossy(&output.stderr);
473        Err(Error::RebaseFailed(stderr.to_string()))
474    }
475
476    /// Get list of files with conflicts.
477    ///
478    /// # Errors
479    /// Returns error if status check fails.
480    pub fn conflicting_files(&self) -> Result<Vec<String>> {
481        let statuses = self.inner.statuses(None)?;
482        let conflicts: Vec<String> = statuses
483            .iter()
484            .filter(|s| s.status().is_conflicted())
485            .filter_map(|s| s.path().map(String::from))
486            .collect();
487        Ok(conflicts)
488    }
489
490    /// Abort an in-progress rebase.
491    ///
492    /// # Errors
493    /// Returns error if abort fails.
494    pub fn rebase_abort(&self) -> Result<()> {
495        let workdir = self.workdir().ok_or(Error::NotARepository)?;
496
497        let output = std::process::Command::new("git")
498            .args(["rebase", "--abort"])
499            .current_dir(workdir)
500            .output()
501            .map_err(|e| Error::RebaseFailed(e.to_string()))?;
502
503        if output.status.success() {
504            Ok(())
505        } else {
506            let stderr = String::from_utf8_lossy(&output.stderr);
507            Err(Error::RebaseFailed(stderr.to_string()))
508        }
509    }
510
511    /// Continue an in-progress rebase.
512    ///
513    /// # Errors
514    /// Returns error if continue fails or new conflicts occur.
515    pub fn rebase_continue(&self) -> Result<()> {
516        let workdir = self.workdir().ok_or(Error::NotARepository)?;
517
518        let output = std::process::Command::new("git")
519            .args(["rebase", "--continue"])
520            .current_dir(workdir)
521            .output()
522            .map_err(|e| Error::RebaseFailed(e.to_string()))?;
523
524        if output.status.success() {
525            return Ok(());
526        }
527
528        // Check if it's a conflict
529        if self.is_rebasing() {
530            let conflicts = self.conflicting_files()?;
531            return Err(Error::RebaseConflict(conflicts));
532        }
533
534        let stderr = String::from_utf8_lossy(&output.stderr);
535        Err(Error::RebaseFailed(stderr.to_string()))
536    }
537
538    // === Remote operations ===
539
540    /// Get the URL of the origin remote.
541    ///
542    /// # Errors
543    /// Returns error if origin remote is not found.
544    pub fn origin_url(&self) -> Result<String> {
545        let remote = self
546            .inner
547            .find_remote("origin")
548            .map_err(|_| Error::RemoteNotFound("origin".into()))?;
549
550        remote
551            .url()
552            .map(String::from)
553            .ok_or_else(|| Error::RemoteNotFound("origin".into()))
554    }
555
556    /// Parse owner and repo name from a GitHub URL.
557    ///
558    /// Supports both HTTPS and SSH URLs:
559    /// - `https://github.com/owner/repo.git`
560    /// - `git@github.com:owner/repo.git`
561    ///
562    /// # Errors
563    /// Returns error if URL cannot be parsed.
564    pub fn parse_github_remote(url: &str) -> Result<(String, String)> {
565        // SSH format: git@github.com:owner/repo.git
566        if let Some(rest) = url.strip_prefix("git@github.com:") {
567            let path = rest.strip_suffix(".git").unwrap_or(rest);
568            if let Some((owner, repo)) = path.split_once('/') {
569                return Ok((owner.to_string(), repo.to_string()));
570            }
571        }
572
573        // HTTPS format: https://github.com/owner/repo.git
574        if let Some(rest) = url
575            .strip_prefix("https://github.com/")
576            .or_else(|| url.strip_prefix("http://github.com/"))
577        {
578            let path = rest.strip_suffix(".git").unwrap_or(rest);
579            if let Some((owner, repo)) = path.split_once('/') {
580                return Ok((owner.to_string(), repo.to_string()));
581            }
582        }
583
584        Err(Error::InvalidRemoteUrl(url.to_string()))
585    }
586
587    /// Push a branch to the remote.
588    ///
589    /// # Errors
590    /// Returns error if push fails.
591    pub fn push(&self, branch: &str, force: bool) -> Result<()> {
592        let workdir = self.workdir().ok_or(Error::NotARepository)?;
593
594        let mut args = vec!["push", "-u", "origin", branch];
595        if force {
596            args.insert(1, "--force-with-lease");
597        }
598
599        let output = std::process::Command::new("git")
600            .args(&args)
601            .current_dir(workdir)
602            .output()
603            .map_err(|e| Error::PushFailed(e.to_string()))?;
604
605        if output.status.success() {
606            Ok(())
607        } else {
608            let stderr = String::from_utf8_lossy(&output.stderr);
609            Err(Error::PushFailed(stderr.to_string()))
610        }
611    }
612
613    /// Fetch a branch from origin.
614    ///
615    /// # Errors
616    /// Returns error if fetch fails.
617    pub fn fetch(&self, branch: &str) -> Result<()> {
618        let workdir = self.workdir().ok_or(Error::NotARepository)?;
619
620        // Use refspec to update both remote tracking branch and local branch
621        // Format: origin/branch:refs/heads/branch
622        let refspec = format!("{branch}:refs/heads/{branch}");
623        let output = std::process::Command::new("git")
624            .args(["fetch", "origin", &refspec])
625            .current_dir(workdir)
626            .output()
627            .map_err(|e| Error::FetchFailed(e.to_string()))?;
628
629        if output.status.success() {
630            Ok(())
631        } else {
632            let stderr = String::from_utf8_lossy(&output.stderr);
633            Err(Error::FetchFailed(stderr.to_string()))
634        }
635    }
636
637    /// Pull (fast-forward only) the current branch from origin.
638    ///
639    /// This fetches and merges `origin/<branch>` into the current branch,
640    /// but only if it can be fast-forwarded.
641    ///
642    /// # Errors
643    /// Returns error if pull fails or fast-forward is not possible.
644    pub fn pull_ff(&self) -> Result<()> {
645        let workdir = self.workdir().ok_or(Error::NotARepository)?;
646
647        let output = std::process::Command::new("git")
648            .args(["pull", "--ff-only"])
649            .current_dir(workdir)
650            .output()
651            .map_err(|e| Error::FetchFailed(e.to_string()))?;
652
653        if output.status.success() {
654            Ok(())
655        } else {
656            let stderr = String::from_utf8_lossy(&output.stderr);
657            Err(Error::FetchFailed(stderr.to_string()))
658        }
659    }
660
661    // === Low-level access ===
662
663    /// Get a reference to the underlying git2 repository.
664    ///
665    /// Use sparingly - prefer high-level methods.
666    #[must_use]
667    pub const fn inner(&self) -> &git2::Repository {
668        &self.inner
669    }
670}
671
672impl std::fmt::Debug for Repository {
673    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
674        f.debug_struct("Repository")
675            .field("path", &self.git_dir())
676            .finish()
677    }
678}
679
#[cfg(test)]
#[allow(clippy::unwrap_used)] // tests: a panic is the desired failure mode
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Initial branch name, pinned so tests don't depend on the user's
    /// `init.defaultBranch` configuration.
    const DEFAULT_BRANCH: &str = "main";

    /// Create a temporary repository on `DEFAULT_BRANCH` with one empty commit.
    ///
    /// The returned `TempDir` must outlive the `Repository`; dropping it
    /// deletes the repository from disk.
    fn init_test_repo() -> (TempDir, Repository) {
        let temp = TempDir::new().unwrap();
        let mut init_opts = git2::RepositoryInitOptions::new();
        init_opts.initial_head(DEFAULT_BRANCH);
        let repo = git2::Repository::init_opts(temp.path(), &init_opts).unwrap();

        // Create initial commit with owned signature (avoids borrowing repo)
        let sig = git2::Signature::now("Test", "test@example.com").unwrap();
        let tree_id = repo.index().unwrap().write_tree().unwrap();
        let tree = repo.find_tree(tree_id).unwrap();
        repo.commit(Some("HEAD"), &sig, &sig, "Initial commit", &tree, &[])
            .unwrap();
        // `tree` borrows `repo`, which is moved into the wrapper below.
        drop(tree);

        (temp, Repository { inner: repo })
    }

    /// Commit the current HEAD tree again on top of HEAD, returning the new id.
    fn add_empty_commit(repo: &Repository, message: &str) -> Oid {
        let sig = git2::Signature::now("Test", "test@example.com").unwrap();
        let parent = repo.inner.head().unwrap().peel_to_commit().unwrap();
        let tree = parent.tree().unwrap();
        repo.inner
            .commit(Some("HEAD"), &sig, &sig, message, &tree, &[&parent])
            .unwrap()
    }

    #[test]
    fn test_current_branch() {
        let (_temp, repo) = init_test_repo();
        assert_eq!(repo.current_branch().unwrap(), DEFAULT_BRANCH);
        assert!(!repo.head_detached().unwrap());
    }

    #[test]
    fn test_create_and_checkout_branch() {
        let (_temp, repo) = init_test_repo();

        let created = repo.create_branch("feature/test").unwrap();
        assert!(repo.branch_exists("feature/test"));
        assert_eq!(created, repo.branch_commit(DEFAULT_BRANCH).unwrap());

        repo.checkout("feature/test").unwrap();
        assert_eq!(repo.current_branch().unwrap(), "feature/test");
    }

    #[test]
    fn test_is_clean() {
        let (temp, repo) = init_test_repo();

        assert!(repo.is_clean().unwrap());

        // Create and commit a tracked file
        fs::write(temp.path().join("test.txt"), "initial").unwrap();
        {
            let mut index = repo.inner.index().unwrap();
            index.add_path(std::path::Path::new("test.txt")).unwrap();
            index.write().unwrap();
            let tree_id = index.write_tree().unwrap();
            let tree = repo.inner.find_tree(tree_id).unwrap();
            let parent = repo.inner.head().unwrap().peel_to_commit().unwrap();
            let sig = git2::Signature::now("Test", "test@example.com").unwrap();
            repo.inner
                .commit(Some("HEAD"), &sig, &sig, "Add test file", &tree, &[&parent])
                .unwrap();
        }

        // Should still be clean after commit
        assert!(repo.is_clean().unwrap());

        // Modify tracked file
        fs::write(temp.path().join("test.txt"), "modified").unwrap();
        assert!(!repo.is_clean().unwrap());
    }

    #[test]
    fn test_list_branches() {
        let (_temp, repo) = init_test_repo();

        repo.create_branch("feature/a").unwrap();
        repo.create_branch("feature/b").unwrap();

        let branches = repo.list_branches().unwrap();
        assert_eq!(branches.len(), 3); // main + 2 features
        assert!(branches.iter().any(|b| b == DEFAULT_BRANCH));
        assert!(branches.iter().any(|b| b == "feature/a"));
        assert!(branches.iter().any(|b| b == "feature/b"));
    }

    #[test]
    fn test_delete_missing_branch_fails() {
        let (_temp, repo) = init_test_repo();
        assert!(repo.delete_branch("does-not-exist").is_err());
    }

    #[test]
    fn test_branch_commit_message() {
        let (_temp, repo) = init_test_repo();
        assert_eq!(
            repo.branch_commit_message(DEFAULT_BRANCH).unwrap(),
            "Initial commit"
        );
    }

    #[test]
    fn test_commits_between() {
        let (_temp, repo) = init_test_repo();
        let first = repo.branch_commit(DEFAULT_BRANCH).unwrap();
        let second = add_empty_commit(&repo, "Second commit");

        assert_eq!(repo.count_commits_between(first, second).unwrap(), 1);
        assert_eq!(repo.commits_between(first, second).unwrap(), vec![second]);
        assert_eq!(repo.count_commits_between(second, second).unwrap(), 0);
        assert_eq!(repo.merge_base(first, second).unwrap(), first);
    }

    #[test]
    fn test_parse_github_remote() {
        let expected = ("owner".to_string(), "repo".to_string());
        for url in [
            "git@github.com:owner/repo.git",
            "git@github.com:owner/repo",
            "https://github.com/owner/repo.git",
            "https://github.com/owner/repo",
            "http://github.com/owner/repo.git",
        ] {
            assert_eq!(Repository::parse_github_remote(url).unwrap(), expected);
        }
        assert!(Repository::parse_github_remote("https://gitlab.com/owner/repo.git").is_err());
        assert!(Repository::parse_github_remote("git@github.com:owner").is_err());
    }
}