rung_git/
repository.rs

1//! Repository wrapper providing high-level git operations.
2
3use std::path::Path;
4
5use git2::{BranchType, Oid, RepositoryState, Signature};
6
7use crate::error::{Error, Result};
8
/// High-level wrapper around a git repository.
///
/// Owns a single [`git2::Repository`] handle. Methods on this type either
/// delegate to that handle or run the `git` executable in the repository's
/// working directory.
pub struct Repository {
    /// Underlying libgit2 repository handle that the methods delegate to.
    inner: git2::Repository,
}
13
14impl Repository {
15    /// Open a repository at the given path.
16    ///
17    /// # Errors
18    /// Returns error if no repository found at path or any parent.
19    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
20        let inner = git2::Repository::discover(path)?;
21        Ok(Self { inner })
22    }
23
24    /// Open the repository containing the current directory.
25    ///
26    /// # Errors
27    /// Returns error if not inside a git repository.
28    pub fn open_current() -> Result<Self> {
29        Self::open(".")
30    }
31
32    /// Get the path to the repository root (workdir).
33    #[must_use]
34    pub fn workdir(&self) -> Option<&Path> {
35        self.inner.workdir()
36    }
37
38    /// Get the path to the .git directory.
39    #[must_use]
40    pub fn git_dir(&self) -> &Path {
41        self.inner.path()
42    }
43
44    /// Get the current repository state.
45    #[must_use]
46    pub fn state(&self) -> RepositoryState {
47        self.inner.state()
48    }
49
50    /// Check if there's a rebase in progress.
51    #[must_use]
52    pub fn is_rebasing(&self) -> bool {
53        matches!(
54            self.state(),
55            RepositoryState::Rebase
56                | RepositoryState::RebaseInteractive
57                | RepositoryState::RebaseMerge
58        )
59    }
60
61    // === Branch operations ===
62
63    /// Get the name of the current branch.
64    ///
65    /// # Errors
66    /// Returns error if HEAD is detached.
67    pub fn current_branch(&self) -> Result<String> {
68        let head = self.inner.head()?;
69        if !head.is_branch() {
70            return Err(Error::DetachedHead);
71        }
72
73        head.shorthand()
74            .map(String::from)
75            .ok_or(Error::DetachedHead)
76    }
77
78    /// Get the commit SHA for a branch.
79    ///
80    /// # Errors
81    /// Returns error if branch doesn't exist.
82    pub fn branch_commit(&self, branch_name: &str) -> Result<Oid> {
83        let branch = self
84            .inner
85            .find_branch(branch_name, BranchType::Local)
86            .map_err(|_| Error::BranchNotFound(branch_name.into()))?;
87
88        branch
89            .get()
90            .target()
91            .ok_or_else(|| Error::BranchNotFound(branch_name.into()))
92    }
93
94    /// Get the commit ID of a remote branch tip.
95    ///
96    /// # Errors
97    /// Returns error if branch not found.
98    pub fn remote_branch_commit(&self, branch_name: &str) -> Result<Oid> {
99        let ref_name = format!("refs/remotes/origin/{branch_name}");
100        let reference = self
101            .inner
102            .find_reference(&ref_name)
103            .map_err(|_| Error::BranchNotFound(format!("origin/{branch_name}")))?;
104
105        reference
106            .target()
107            .ok_or_else(|| Error::BranchNotFound(format!("origin/{branch_name}")))
108    }
109
110    /// Create a new branch at the current HEAD.
111    ///
112    /// # Errors
113    /// Returns error if branch creation fails.
114    pub fn create_branch(&self, name: &str) -> Result<Oid> {
115        let head_commit = self.inner.head()?.peel_to_commit()?;
116        let branch = self.inner.branch(name, &head_commit, false)?;
117
118        branch
119            .get()
120            .target()
121            .ok_or_else(|| Error::BranchNotFound(name.into()))
122    }
123
124    /// Checkout a branch.
125    ///
126    /// # Errors
127    /// Returns error if checkout fails.
128    pub fn checkout(&self, branch_name: &str) -> Result<()> {
129        let branch = self
130            .inner
131            .find_branch(branch_name, BranchType::Local)
132            .map_err(|_| Error::BranchNotFound(branch_name.into()))?;
133
134        let reference = branch.get();
135        let object = reference.peel(git2::ObjectType::Commit)?;
136
137        self.inner.checkout_tree(&object, None)?;
138        self.inner.set_head(&format!("refs/heads/{branch_name}"))?;
139
140        Ok(())
141    }
142
143    /// List all local branches.
144    ///
145    /// # Errors
146    /// Returns error if branch listing fails.
147    pub fn list_branches(&self) -> Result<Vec<String>> {
148        let branches = self.inner.branches(Some(BranchType::Local))?;
149
150        let names: Vec<String> = branches
151            .filter_map(std::result::Result::ok)
152            .filter_map(|(b, _)| b.name().ok().flatten().map(String::from))
153            .collect();
154
155        Ok(names)
156    }
157
158    /// Check if a branch exists.
159    #[must_use]
160    pub fn branch_exists(&self, name: &str) -> bool {
161        self.inner.find_branch(name, BranchType::Local).is_ok()
162    }
163
164    /// Delete a local branch.
165    ///
166    /// # Errors
167    /// Returns error if branch deletion fails.
168    pub fn delete_branch(&self, name: &str) -> Result<()> {
169        let mut branch = self.inner.find_branch(name, BranchType::Local)?;
170        branch.delete()?;
171        Ok(())
172    }
173
174    // === Working directory state ===
175
176    /// Check if the working directory is clean (no modified or staged files).
177    ///
178    /// Untracked files are ignored - only tracked files that have been
179    /// modified or staged count as "dirty".
180    ///
181    /// # Errors
182    /// Returns error if status check fails.
183    pub fn is_clean(&self) -> Result<bool> {
184        let mut opts = git2::StatusOptions::new();
185        opts.include_untracked(false)
186            .include_ignored(false)
187            .include_unmodified(false)
188            .exclude_submodules(true);
189        let statuses = self.inner.statuses(Some(&mut opts))?;
190
191        // Check if any status indicates modified/staged files
192        for entry in statuses.iter() {
193            let status = entry.status();
194            // These indicate actual changes to tracked files
195            if status.intersects(
196                git2::Status::INDEX_NEW
197                    | git2::Status::INDEX_MODIFIED
198                    | git2::Status::INDEX_DELETED
199                    | git2::Status::INDEX_RENAMED
200                    | git2::Status::INDEX_TYPECHANGE
201                    | git2::Status::WT_MODIFIED
202                    | git2::Status::WT_DELETED
203                    | git2::Status::WT_TYPECHANGE
204                    | git2::Status::WT_RENAMED,
205            ) {
206                return Ok(false);
207            }
208        }
209        Ok(true)
210    }
211
212    /// Ensure working directory is clean, returning error if not.
213    ///
214    /// # Errors
215    /// Returns `DirtyWorkingDirectory` if there are uncommitted changes.
216    pub fn require_clean(&self) -> Result<()> {
217        if self.is_clean()? {
218            Ok(())
219        } else {
220            Err(Error::DirtyWorkingDirectory)
221        }
222    }
223
224    // === Commit operations ===
225
226    /// Get a commit by its SHA.
227    ///
228    /// # Errors
229    /// Returns error if commit not found.
230    pub fn find_commit(&self, oid: Oid) -> Result<git2::Commit<'_>> {
231        Ok(self.inner.find_commit(oid)?)
232    }
233
234    /// Get the merge base between two commits.
235    ///
236    /// # Errors
237    /// Returns error if merge base calculation fails.
238    pub fn merge_base(&self, one: Oid, two: Oid) -> Result<Oid> {
239        Ok(self.inner.merge_base(one, two)?)
240    }
241
242    /// Count commits between two points.
243    ///
244    /// # Errors
245    /// Returns error if revwalk fails.
246    pub fn count_commits_between(&self, from: Oid, to: Oid) -> Result<usize> {
247        let mut revwalk = self.inner.revwalk()?;
248        revwalk.push(to)?;
249        revwalk.hide(from)?;
250
251        Ok(revwalk.count())
252    }
253
254    // === Reset operations ===
255
256    /// Hard reset a branch to a specific commit.
257    ///
258    /// # Errors
259    /// Returns error if reset fails.
260    pub fn reset_branch(&self, branch_name: &str, target: Oid) -> Result<()> {
261        let commit = self.inner.find_commit(target)?;
262        let reference_name = format!("refs/heads/{branch_name}");
263
264        self.inner.reference(
265            &reference_name,
266            target,
267            true, // force
268            &format!("rung: reset to {}", &target.to_string()[..8]),
269        )?;
270
271        // If this is the current branch, also update working directory
272        if self.current_branch().ok().as_deref() == Some(branch_name) {
273            self.inner
274                .reset(commit.as_object(), git2::ResetType::Hard, None)?;
275        }
276
277        Ok(())
278    }
279
280    // === Signature ===
281
282    /// Get the default signature for commits.
283    ///
284    /// # Errors
285    /// Returns error if git config doesn't have user.name/email.
286    pub fn signature(&self) -> Result<Signature<'_>> {
287        Ok(self.inner.signature()?)
288    }
289
290    // === Rebase operations ===
291
292    /// Rebase the current branch onto a target commit.
293    ///
294    /// Returns `Ok(())` on success, or `Err(RebaseConflict)` if there are conflicts.
295    ///
296    /// # Errors
297    /// Returns error if rebase fails or conflicts occur.
298    pub fn rebase_onto(&self, target: Oid) -> Result<()> {
299        let workdir = self.workdir().ok_or(Error::NotARepository)?;
300
301        let output = std::process::Command::new("git")
302            .args(["rebase", &target.to_string()])
303            .current_dir(workdir)
304            .output()
305            .map_err(|e| Error::RebaseFailed(e.to_string()))?;
306
307        if output.status.success() {
308            return Ok(());
309        }
310
311        // Check if it's a conflict
312        if self.is_rebasing() {
313            let conflicts = self.conflicting_files()?;
314            return Err(Error::RebaseConflict(conflicts));
315        }
316
317        let stderr = String::from_utf8_lossy(&output.stderr);
318        Err(Error::RebaseFailed(stderr.to_string()))
319    }
320
321    /// Rebase the current branch onto a new base, replaying only commits after `old_base`.
322    ///
323    /// This is equivalent to `git rebase --onto <new_base> <old_base>`.
324    /// Use this when the `old_base` was squash-merged and you want to bring only
325    /// the unique commits from the current branch.
326    ///
327    /// # Errors
328    /// Returns error if rebase fails or conflicts occur.
329    pub fn rebase_onto_from(&self, new_base: Oid, old_base: Oid) -> Result<()> {
330        let workdir = self.workdir().ok_or(Error::NotARepository)?;
331
332        let output = std::process::Command::new("git")
333            .args([
334                "rebase",
335                "--onto",
336                &new_base.to_string(),
337                &old_base.to_string(),
338            ])
339            .current_dir(workdir)
340            .output()
341            .map_err(|e| Error::RebaseFailed(e.to_string()))?;
342
343        if output.status.success() {
344            return Ok(());
345        }
346
347        // Check if it's a conflict
348        if self.is_rebasing() {
349            let conflicts = self.conflicting_files()?;
350            return Err(Error::RebaseConflict(conflicts));
351        }
352
353        let stderr = String::from_utf8_lossy(&output.stderr);
354        Err(Error::RebaseFailed(stderr.to_string()))
355    }
356
357    /// Get list of files with conflicts.
358    ///
359    /// # Errors
360    /// Returns error if status check fails.
361    pub fn conflicting_files(&self) -> Result<Vec<String>> {
362        let statuses = self.inner.statuses(None)?;
363        let conflicts: Vec<String> = statuses
364            .iter()
365            .filter(|s| s.status().is_conflicted())
366            .filter_map(|s| s.path().map(String::from))
367            .collect();
368        Ok(conflicts)
369    }
370
371    /// Abort an in-progress rebase.
372    ///
373    /// # Errors
374    /// Returns error if abort fails.
375    pub fn rebase_abort(&self) -> Result<()> {
376        let workdir = self.workdir().ok_or(Error::NotARepository)?;
377
378        let output = std::process::Command::new("git")
379            .args(["rebase", "--abort"])
380            .current_dir(workdir)
381            .output()
382            .map_err(|e| Error::RebaseFailed(e.to_string()))?;
383
384        if output.status.success() {
385            Ok(())
386        } else {
387            let stderr = String::from_utf8_lossy(&output.stderr);
388            Err(Error::RebaseFailed(stderr.to_string()))
389        }
390    }
391
392    /// Continue an in-progress rebase.
393    ///
394    /// # Errors
395    /// Returns error if continue fails or new conflicts occur.
396    pub fn rebase_continue(&self) -> Result<()> {
397        let workdir = self.workdir().ok_or(Error::NotARepository)?;
398
399        let output = std::process::Command::new("git")
400            .args(["rebase", "--continue"])
401            .current_dir(workdir)
402            .output()
403            .map_err(|e| Error::RebaseFailed(e.to_string()))?;
404
405        if output.status.success() {
406            return Ok(());
407        }
408
409        // Check if it's a conflict
410        if self.is_rebasing() {
411            let conflicts = self.conflicting_files()?;
412            return Err(Error::RebaseConflict(conflicts));
413        }
414
415        let stderr = String::from_utf8_lossy(&output.stderr);
416        Err(Error::RebaseFailed(stderr.to_string()))
417    }
418
419    // === Remote operations ===
420
421    /// Get the URL of the origin remote.
422    ///
423    /// # Errors
424    /// Returns error if origin remote is not found.
425    pub fn origin_url(&self) -> Result<String> {
426        let remote = self
427            .inner
428            .find_remote("origin")
429            .map_err(|_| Error::RemoteNotFound("origin".into()))?;
430
431        remote
432            .url()
433            .map(String::from)
434            .ok_or_else(|| Error::RemoteNotFound("origin".into()))
435    }
436
437    /// Parse owner and repo name from a GitHub URL.
438    ///
439    /// Supports both HTTPS and SSH URLs:
440    /// - `https://github.com/owner/repo.git`
441    /// - `git@github.com:owner/repo.git`
442    ///
443    /// # Errors
444    /// Returns error if URL cannot be parsed.
445    pub fn parse_github_remote(url: &str) -> Result<(String, String)> {
446        // SSH format: git@github.com:owner/repo.git
447        if let Some(rest) = url.strip_prefix("git@github.com:") {
448            let path = rest.strip_suffix(".git").unwrap_or(rest);
449            if let Some((owner, repo)) = path.split_once('/') {
450                return Ok((owner.to_string(), repo.to_string()));
451            }
452        }
453
454        // HTTPS format: https://github.com/owner/repo.git
455        if let Some(rest) = url
456            .strip_prefix("https://github.com/")
457            .or_else(|| url.strip_prefix("http://github.com/"))
458        {
459            let path = rest.strip_suffix(".git").unwrap_or(rest);
460            if let Some((owner, repo)) = path.split_once('/') {
461                return Ok((owner.to_string(), repo.to_string()));
462            }
463        }
464
465        Err(Error::InvalidRemoteUrl(url.to_string()))
466    }
467
468    /// Push a branch to the remote.
469    ///
470    /// # Errors
471    /// Returns error if push fails.
472    pub fn push(&self, branch: &str, force: bool) -> Result<()> {
473        let workdir = self.workdir().ok_or(Error::NotARepository)?;
474
475        let mut args = vec!["push", "-u", "origin", branch];
476        if force {
477            args.insert(1, "--force-with-lease");
478        }
479
480        let output = std::process::Command::new("git")
481            .args(&args)
482            .current_dir(workdir)
483            .output()
484            .map_err(|e| Error::PushFailed(e.to_string()))?;
485
486        if output.status.success() {
487            Ok(())
488        } else {
489            let stderr = String::from_utf8_lossy(&output.stderr);
490            Err(Error::PushFailed(stderr.to_string()))
491        }
492    }
493
494    /// Fetch a branch from origin.
495    ///
496    /// # Errors
497    /// Returns error if fetch fails.
498    pub fn fetch(&self, branch: &str) -> Result<()> {
499        let workdir = self.workdir().ok_or(Error::NotARepository)?;
500
501        // Use refspec to update both remote tracking branch and local branch
502        // Format: origin/branch:refs/heads/branch
503        let refspec = format!("{branch}:refs/heads/{branch}");
504        let output = std::process::Command::new("git")
505            .args(["fetch", "origin", &refspec])
506            .current_dir(workdir)
507            .output()
508            .map_err(|e| Error::FetchFailed(e.to_string()))?;
509
510        if output.status.success() {
511            Ok(())
512        } else {
513            let stderr = String::from_utf8_lossy(&output.stderr);
514            Err(Error::FetchFailed(stderr.to_string()))
515        }
516    }
517
518    /// Pull (fast-forward only) the current branch from origin.
519    ///
520    /// This fetches and merges `origin/<branch>` into the current branch,
521    /// but only if it can be fast-forwarded.
522    ///
523    /// # Errors
524    /// Returns error if pull fails or fast-forward is not possible.
525    pub fn pull_ff(&self) -> Result<()> {
526        let workdir = self.workdir().ok_or(Error::NotARepository)?;
527
528        let output = std::process::Command::new("git")
529            .args(["pull", "--ff-only"])
530            .current_dir(workdir)
531            .output()
532            .map_err(|e| Error::FetchFailed(e.to_string()))?;
533
534        if output.status.success() {
535            Ok(())
536        } else {
537            let stderr = String::from_utf8_lossy(&output.stderr);
538            Err(Error::FetchFailed(stderr.to_string()))
539        }
540    }
541
542    // === Low-level access ===
543
544    /// Get a reference to the underlying git2 repository.
545    ///
546    /// Use sparingly - prefer high-level methods.
547    #[must_use]
548    pub const fn inner(&self) -> &git2::Repository {
549        &self.inner
550    }
551}
552
553impl std::fmt::Debug for Repository {
554    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
555        f.debug_struct("Repository")
556            .field("path", &self.git_dir())
557            .finish()
558    }
559}
560
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Create a fresh repository in a temp dir with one empty initial commit.
    ///
    /// The `TempDir` is returned so the directory outlives the repository.
    fn init_test_repo() -> (TempDir, Repository) {
        let temp = TempDir::new().unwrap();
        let inner = git2::Repository::init(temp.path()).unwrap();

        // Scope the tree so its borrow of `inner` ends before `inner` is moved.
        {
            let author = git2::Signature::now("Test", "test@example.com").unwrap();
            let tree_id = inner.index().unwrap().write_tree().unwrap();
            let empty_tree = inner.find_tree(tree_id).unwrap();
            inner
                .commit(
                    Some("HEAD"),
                    &author,
                    &author,
                    "Initial commit",
                    &empty_tree,
                    &[],
                )
                .unwrap();
        }

        (temp, Repository { inner })
    }

    #[test]
    fn test_current_branch() {
        let (_temp, repo) = init_test_repo();

        // The default branch name depends on the local git configuration.
        let branch = repo.current_branch().unwrap();
        assert!(matches!(branch.as_str(), "main" | "master"));
    }

    #[test]
    fn test_create_and_checkout_branch() {
        let (_temp, repo) = init_test_repo();
        let name = "feature/test";

        repo.create_branch(name).unwrap();
        assert!(repo.branch_exists(name));

        repo.checkout(name).unwrap();
        assert_eq!(repo.current_branch().unwrap(), name);
    }

    #[test]
    fn test_is_clean() {
        let (temp, repo) = init_test_repo();
        let file_path = temp.path().join("test.txt");

        assert!(repo.is_clean().unwrap());

        // Create a file, stage it and commit it so it becomes tracked.
        fs::write(&file_path, "initial").unwrap();
        {
            let git = &repo.inner;
            let mut index = git.index().unwrap();
            index.add_path(std::path::Path::new("test.txt")).unwrap();
            index.write().unwrap();

            let tree = git.find_tree(index.write_tree().unwrap()).unwrap();
            let parent = git.head().unwrap().peel_to_commit().unwrap();
            let author = git2::Signature::now("Test", "test@example.com").unwrap();
            git.commit(
                Some("HEAD"),
                &author,
                &author,
                "Add test file",
                &tree,
                &[&parent],
            )
            .unwrap();
        }

        // Nothing is outstanding right after the commit.
        assert!(repo.is_clean().unwrap());

        // Changing the tracked file makes the tree dirty.
        fs::write(&file_path, "modified").unwrap();
        assert!(!repo.is_clean().unwrap());
    }

    #[test]
    fn test_list_branches() {
        let (_temp, repo) = init_test_repo();

        for name in ["feature/a", "feature/b"] {
            repo.create_branch(name).unwrap();
        }

        let branches = repo.list_branches().unwrap();
        assert!(branches.len() >= 3); // main/master + 2 features
        for expected in ["feature/a", "feature/b"] {
            assert!(branches.iter().any(|b| b == expected));
        }
    }
}