rung_git/
repository.rs

1//! Repository wrapper providing high-level git operations.
2
3use std::path::Path;
4
5use git2::{BranchType, Oid, RepositoryState, Signature};
6
7use crate::error::{Error, Result};
8
9/// High-level wrapper around a git repository.
10pub struct Repository {
11    inner: git2::Repository,
12}
13
14impl Repository {
15    /// Open a repository at the given path.
16    ///
17    /// # Errors
18    /// Returns error if no repository found at path or any parent.
19    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
20        let inner = git2::Repository::discover(path)?;
21        Ok(Self { inner })
22    }
23
24    /// Open the repository containing the current directory.
25    ///
26    /// # Errors
27    /// Returns error if not inside a git repository.
28    pub fn open_current() -> Result<Self> {
29        Self::open(".")
30    }
31
32    /// Get the path to the repository root (workdir).
33    #[must_use]
34    pub fn workdir(&self) -> Option<&Path> {
35        self.inner.workdir()
36    }
37
38    /// Get the path to the .git directory.
39    #[must_use]
40    pub fn git_dir(&self) -> &Path {
41        self.inner.path()
42    }
43
44    /// Get the current repository state.
45    #[must_use]
46    pub fn state(&self) -> RepositoryState {
47        self.inner.state()
48    }
49
50    /// Check if there's a rebase in progress.
51    #[must_use]
52    pub fn is_rebasing(&self) -> bool {
53        matches!(
54            self.state(),
55            RepositoryState::Rebase
56                | RepositoryState::RebaseInteractive
57                | RepositoryState::RebaseMerge
58        )
59    }
60
61    // === Branch operations ===
62
63    /// Get the name of the current branch.
64    ///
65    /// # Errors
66    /// Returns error if HEAD is detached.
67    pub fn current_branch(&self) -> Result<String> {
68        let head = self.inner.head()?;
69        if !head.is_branch() {
70            return Err(Error::DetachedHead);
71        }
72
73        head.shorthand()
74            .map(String::from)
75            .ok_or(Error::DetachedHead)
76    }
77
78    /// Get the commit SHA for a branch.
79    ///
80    /// # Errors
81    /// Returns error if branch doesn't exist.
82    pub fn branch_commit(&self, branch_name: &str) -> Result<Oid> {
83        let branch = self
84            .inner
85            .find_branch(branch_name, BranchType::Local)
86            .map_err(|_| Error::BranchNotFound(branch_name.into()))?;
87
88        branch
89            .get()
90            .target()
91            .ok_or_else(|| Error::BranchNotFound(branch_name.into()))
92    }
93
94    /// Get the commit ID of a remote branch tip.
95    ///
96    /// # Errors
97    /// Returns error if branch not found.
98    pub fn remote_branch_commit(&self, branch_name: &str) -> Result<Oid> {
99        let ref_name = format!("refs/remotes/origin/{branch_name}");
100        let reference = self
101            .inner
102            .find_reference(&ref_name)
103            .map_err(|_| Error::BranchNotFound(format!("origin/{branch_name}")))?;
104
105        reference
106            .target()
107            .ok_or_else(|| Error::BranchNotFound(format!("origin/{branch_name}")))
108    }
109
110    /// Create a new branch at the current HEAD.
111    ///
112    /// # Errors
113    /// Returns error if branch creation fails.
114    pub fn create_branch(&self, name: &str) -> Result<Oid> {
115        let head_commit = self.inner.head()?.peel_to_commit()?;
116        let branch = self.inner.branch(name, &head_commit, false)?;
117
118        branch
119            .get()
120            .target()
121            .ok_or_else(|| Error::BranchNotFound(name.into()))
122    }
123
124    /// Checkout a branch.
125    ///
126    /// # Errors
127    /// Returns error if checkout fails.
128    pub fn checkout(&self, branch_name: &str) -> Result<()> {
129        let branch = self
130            .inner
131            .find_branch(branch_name, BranchType::Local)
132            .map_err(|_| Error::BranchNotFound(branch_name.into()))?;
133
134        let reference = branch.get();
135        let object = reference.peel(git2::ObjectType::Commit)?;
136
137        self.inner.checkout_tree(&object, None)?;
138        self.inner.set_head(&format!("refs/heads/{branch_name}"))?;
139
140        Ok(())
141    }
142
143    /// List all local branches.
144    ///
145    /// # Errors
146    /// Returns error if branch listing fails.
147    pub fn list_branches(&self) -> Result<Vec<String>> {
148        let branches = self.inner.branches(Some(BranchType::Local))?;
149
150        let names: Vec<String> = branches
151            .filter_map(std::result::Result::ok)
152            .filter_map(|(b, _)| b.name().ok().flatten().map(String::from))
153            .collect();
154
155        Ok(names)
156    }
157
158    /// Check if a branch exists.
159    #[must_use]
160    pub fn branch_exists(&self, name: &str) -> bool {
161        self.inner.find_branch(name, BranchType::Local).is_ok()
162    }
163
164    /// Delete a local branch.
165    ///
166    /// # Errors
167    /// Returns error if branch deletion fails.
168    pub fn delete_branch(&self, name: &str) -> Result<()> {
169        let mut branch = self.inner.find_branch(name, BranchType::Local)?;
170        branch.delete()?;
171        Ok(())
172    }
173
174    // === Working directory state ===
175
176    /// Check if the working directory is clean (no modified or staged files).
177    ///
178    /// Untracked files are ignored - only tracked files that have been
179    /// modified or staged count as "dirty".
180    ///
181    /// # Errors
182    /// Returns error if status check fails.
183    pub fn is_clean(&self) -> Result<bool> {
184        let mut opts = git2::StatusOptions::new();
185        opts.include_untracked(false)
186            .include_ignored(false)
187            .include_unmodified(false)
188            .exclude_submodules(true);
189        let statuses = self.inner.statuses(Some(&mut opts))?;
190
191        // Check if any status indicates modified/staged files
192        for entry in statuses.iter() {
193            let status = entry.status();
194            // These indicate actual changes to tracked files
195            if status.intersects(
196                git2::Status::INDEX_NEW
197                    | git2::Status::INDEX_MODIFIED
198                    | git2::Status::INDEX_DELETED
199                    | git2::Status::INDEX_RENAMED
200                    | git2::Status::INDEX_TYPECHANGE
201                    | git2::Status::WT_MODIFIED
202                    | git2::Status::WT_DELETED
203                    | git2::Status::WT_TYPECHANGE
204                    | git2::Status::WT_RENAMED,
205            ) {
206                return Ok(false);
207            }
208        }
209        Ok(true)
210    }
211
212    /// Ensure working directory is clean, returning error if not.
213    ///
214    /// # Errors
215    /// Returns `DirtyWorkingDirectory` if there are uncommitted changes.
216    pub fn require_clean(&self) -> Result<()> {
217        if self.is_clean()? {
218            Ok(())
219        } else {
220            Err(Error::DirtyWorkingDirectory)
221        }
222    }
223
224    // === Commit operations ===
225
226    /// Get a commit by its SHA.
227    ///
228    /// # Errors
229    /// Returns error if commit not found.
230    pub fn find_commit(&self, oid: Oid) -> Result<git2::Commit<'_>> {
231        Ok(self.inner.find_commit(oid)?)
232    }
233
234    /// Get the merge base between two commits.
235    ///
236    /// # Errors
237    /// Returns error if merge base calculation fails.
238    pub fn merge_base(&self, one: Oid, two: Oid) -> Result<Oid> {
239        Ok(self.inner.merge_base(one, two)?)
240    }
241
242    /// Count commits between two points.
243    ///
244    /// # Errors
245    /// Returns error if revwalk fails.
246    pub fn count_commits_between(&self, from: Oid, to: Oid) -> Result<usize> {
247        let mut revwalk = self.inner.revwalk()?;
248        revwalk.push(to)?;
249        revwalk.hide(from)?;
250
251        Ok(revwalk.count())
252    }
253
254    /// Get commits between two points.
255    ///
256    /// # Errors
257    /// Return error if revwalk fails.
258    pub fn commits_between(&self, from: Oid, to: Oid) -> Result<Vec<Oid>> {
259        let mut revwalk = self.inner.revwalk()?;
260        revwalk.push(to)?;
261        revwalk.hide(from)?;
262
263        let mut commits = Vec::new();
264        for oid in revwalk {
265            let oid = oid?;
266            commits.push(oid);
267        }
268
269        Ok(commits)
270    }
271
272    // === Reset operations ===
273
274    /// Hard reset a branch to a specific commit.
275    ///
276    /// # Errors
277    /// Returns error if reset fails.
278    pub fn reset_branch(&self, branch_name: &str, target: Oid) -> Result<()> {
279        let commit = self.inner.find_commit(target)?;
280        let reference_name = format!("refs/heads/{branch_name}");
281
282        self.inner.reference(
283            &reference_name,
284            target,
285            true, // force
286            &format!("rung: reset to {}", &target.to_string()[..8]),
287        )?;
288
289        // If this is the current branch, also update working directory
290        if self.current_branch().ok().as_deref() == Some(branch_name) {
291            self.inner
292                .reset(commit.as_object(), git2::ResetType::Hard, None)?;
293        }
294
295        Ok(())
296    }
297
298    // === Signature ===
299
300    /// Get the default signature for commits.
301    ///
302    /// # Errors
303    /// Returns error if git config doesn't have user.name/email.
304    pub fn signature(&self) -> Result<Signature<'_>> {
305        Ok(self.inner.signature()?)
306    }
307
308    // === Rebase operations ===
309
310    /// Rebase the current branch onto a target commit.
311    ///
312    /// Returns `Ok(())` on success, or `Err(RebaseConflict)` if there are conflicts.
313    ///
314    /// # Errors
315    /// Returns error if rebase fails or conflicts occur.
316    pub fn rebase_onto(&self, target: Oid) -> Result<()> {
317        let workdir = self.workdir().ok_or(Error::NotARepository)?;
318
319        let output = std::process::Command::new("git")
320            .args(["rebase", &target.to_string()])
321            .current_dir(workdir)
322            .output()
323            .map_err(|e| Error::RebaseFailed(e.to_string()))?;
324
325        if output.status.success() {
326            return Ok(());
327        }
328
329        // Check if it's a conflict
330        if self.is_rebasing() {
331            let conflicts = self.conflicting_files()?;
332            return Err(Error::RebaseConflict(conflicts));
333        }
334
335        let stderr = String::from_utf8_lossy(&output.stderr);
336        Err(Error::RebaseFailed(stderr.to_string()))
337    }
338
339    /// Rebase the current branch onto a new base, replaying only commits after `old_base`.
340    ///
341    /// This is equivalent to `git rebase --onto <new_base> <old_base>`.
342    /// Use this when the `old_base` was squash-merged and you want to bring only
343    /// the unique commits from the current branch.
344    ///
345    /// # Errors
346    /// Returns error if rebase fails or conflicts occur.
347    pub fn rebase_onto_from(&self, new_base: Oid, old_base: Oid) -> Result<()> {
348        let workdir = self.workdir().ok_or(Error::NotARepository)?;
349
350        let output = std::process::Command::new("git")
351            .args([
352                "rebase",
353                "--onto",
354                &new_base.to_string(),
355                &old_base.to_string(),
356            ])
357            .current_dir(workdir)
358            .output()
359            .map_err(|e| Error::RebaseFailed(e.to_string()))?;
360
361        if output.status.success() {
362            return Ok(());
363        }
364
365        // Check if it's a conflict
366        if self.is_rebasing() {
367            let conflicts = self.conflicting_files()?;
368            return Err(Error::RebaseConflict(conflicts));
369        }
370
371        let stderr = String::from_utf8_lossy(&output.stderr);
372        Err(Error::RebaseFailed(stderr.to_string()))
373    }
374
375    /// Get list of files with conflicts.
376    ///
377    /// # Errors
378    /// Returns error if status check fails.
379    pub fn conflicting_files(&self) -> Result<Vec<String>> {
380        let statuses = self.inner.statuses(None)?;
381        let conflicts: Vec<String> = statuses
382            .iter()
383            .filter(|s| s.status().is_conflicted())
384            .filter_map(|s| s.path().map(String::from))
385            .collect();
386        Ok(conflicts)
387    }
388
389    /// Abort an in-progress rebase.
390    ///
391    /// # Errors
392    /// Returns error if abort fails.
393    pub fn rebase_abort(&self) -> Result<()> {
394        let workdir = self.workdir().ok_or(Error::NotARepository)?;
395
396        let output = std::process::Command::new("git")
397            .args(["rebase", "--abort"])
398            .current_dir(workdir)
399            .output()
400            .map_err(|e| Error::RebaseFailed(e.to_string()))?;
401
402        if output.status.success() {
403            Ok(())
404        } else {
405            let stderr = String::from_utf8_lossy(&output.stderr);
406            Err(Error::RebaseFailed(stderr.to_string()))
407        }
408    }
409
410    /// Continue an in-progress rebase.
411    ///
412    /// # Errors
413    /// Returns error if continue fails or new conflicts occur.
414    pub fn rebase_continue(&self) -> Result<()> {
415        let workdir = self.workdir().ok_or(Error::NotARepository)?;
416
417        let output = std::process::Command::new("git")
418            .args(["rebase", "--continue"])
419            .current_dir(workdir)
420            .output()
421            .map_err(|e| Error::RebaseFailed(e.to_string()))?;
422
423        if output.status.success() {
424            return Ok(());
425        }
426
427        // Check if it's a conflict
428        if self.is_rebasing() {
429            let conflicts = self.conflicting_files()?;
430            return Err(Error::RebaseConflict(conflicts));
431        }
432
433        let stderr = String::from_utf8_lossy(&output.stderr);
434        Err(Error::RebaseFailed(stderr.to_string()))
435    }
436
437    // === Remote operations ===
438
439    /// Get the URL of the origin remote.
440    ///
441    /// # Errors
442    /// Returns error if origin remote is not found.
443    pub fn origin_url(&self) -> Result<String> {
444        let remote = self
445            .inner
446            .find_remote("origin")
447            .map_err(|_| Error::RemoteNotFound("origin".into()))?;
448
449        remote
450            .url()
451            .map(String::from)
452            .ok_or_else(|| Error::RemoteNotFound("origin".into()))
453    }
454
455    /// Parse owner and repo name from a GitHub URL.
456    ///
457    /// Supports both HTTPS and SSH URLs:
458    /// - `https://github.com/owner/repo.git`
459    /// - `git@github.com:owner/repo.git`
460    ///
461    /// # Errors
462    /// Returns error if URL cannot be parsed.
463    pub fn parse_github_remote(url: &str) -> Result<(String, String)> {
464        // SSH format: git@github.com:owner/repo.git
465        if let Some(rest) = url.strip_prefix("git@github.com:") {
466            let path = rest.strip_suffix(".git").unwrap_or(rest);
467            if let Some((owner, repo)) = path.split_once('/') {
468                return Ok((owner.to_string(), repo.to_string()));
469            }
470        }
471
472        // HTTPS format: https://github.com/owner/repo.git
473        if let Some(rest) = url
474            .strip_prefix("https://github.com/")
475            .or_else(|| url.strip_prefix("http://github.com/"))
476        {
477            let path = rest.strip_suffix(".git").unwrap_or(rest);
478            if let Some((owner, repo)) = path.split_once('/') {
479                return Ok((owner.to_string(), repo.to_string()));
480            }
481        }
482
483        Err(Error::InvalidRemoteUrl(url.to_string()))
484    }
485
486    /// Push a branch to the remote.
487    ///
488    /// # Errors
489    /// Returns error if push fails.
490    pub fn push(&self, branch: &str, force: bool) -> Result<()> {
491        let workdir = self.workdir().ok_or(Error::NotARepository)?;
492
493        let mut args = vec!["push", "-u", "origin", branch];
494        if force {
495            args.insert(1, "--force-with-lease");
496        }
497
498        let output = std::process::Command::new("git")
499            .args(&args)
500            .current_dir(workdir)
501            .output()
502            .map_err(|e| Error::PushFailed(e.to_string()))?;
503
504        if output.status.success() {
505            Ok(())
506        } else {
507            let stderr = String::from_utf8_lossy(&output.stderr);
508            Err(Error::PushFailed(stderr.to_string()))
509        }
510    }
511
512    /// Fetch a branch from origin.
513    ///
514    /// # Errors
515    /// Returns error if fetch fails.
516    pub fn fetch(&self, branch: &str) -> Result<()> {
517        let workdir = self.workdir().ok_or(Error::NotARepository)?;
518
519        // Use refspec to update both remote tracking branch and local branch
520        // Format: origin/branch:refs/heads/branch
521        let refspec = format!("{branch}:refs/heads/{branch}");
522        let output = std::process::Command::new("git")
523            .args(["fetch", "origin", &refspec])
524            .current_dir(workdir)
525            .output()
526            .map_err(|e| Error::FetchFailed(e.to_string()))?;
527
528        if output.status.success() {
529            Ok(())
530        } else {
531            let stderr = String::from_utf8_lossy(&output.stderr);
532            Err(Error::FetchFailed(stderr.to_string()))
533        }
534    }
535
536    /// Pull (fast-forward only) the current branch from origin.
537    ///
538    /// This fetches and merges `origin/<branch>` into the current branch,
539    /// but only if it can be fast-forwarded.
540    ///
541    /// # Errors
542    /// Returns error if pull fails or fast-forward is not possible.
543    pub fn pull_ff(&self) -> Result<()> {
544        let workdir = self.workdir().ok_or(Error::NotARepository)?;
545
546        let output = std::process::Command::new("git")
547            .args(["pull", "--ff-only"])
548            .current_dir(workdir)
549            .output()
550            .map_err(|e| Error::FetchFailed(e.to_string()))?;
551
552        if output.status.success() {
553            Ok(())
554        } else {
555            let stderr = String::from_utf8_lossy(&output.stderr);
556            Err(Error::FetchFailed(stderr.to_string()))
557        }
558    }
559
560    // === Low-level access ===
561
562    /// Get a reference to the underlying git2 repository.
563    ///
564    /// Use sparingly - prefer high-level methods.
565    #[must_use]
566    pub const fn inner(&self) -> &git2::Repository {
567        &self.inner
568    }
569}
570
571impl std::fmt::Debug for Repository {
572    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
573        f.debug_struct("Repository")
574            .field("path", &self.git_dir())
575            .finish()
576    }
577}
578
579#[cfg(test)]
580#[allow(clippy::unwrap_used)]
581mod tests {
582    use super::*;
583    use std::fs;
584    use tempfile::TempDir;
585
586    fn init_test_repo() -> (TempDir, Repository) {
587        let temp = TempDir::new().unwrap();
588        let repo = git2::Repository::init(temp.path()).unwrap();
589
590        // Create initial commit with owned signature (avoids borrowing repo)
591        let sig = git2::Signature::now("Test", "test@example.com").unwrap();
592        let tree_id = repo.index().unwrap().write_tree().unwrap();
593        let tree = repo.find_tree(tree_id).unwrap();
594        repo.commit(Some("HEAD"), &sig, &sig, "Initial commit", &tree, &[])
595            .unwrap();
596        drop(tree);
597
598        let wrapped = Repository { inner: repo };
599        (temp, wrapped)
600    }
601
602    #[test]
603    fn test_current_branch() {
604        let (_temp, repo) = init_test_repo();
605        // Default branch after init
606        let branch = repo.current_branch().unwrap();
607        assert!(branch == "main" || branch == "master");
608    }
609
610    #[test]
611    fn test_create_and_checkout_branch() {
612        let (_temp, repo) = init_test_repo();
613
614        repo.create_branch("feature/test").unwrap();
615        assert!(repo.branch_exists("feature/test"));
616
617        repo.checkout("feature/test").unwrap();
618        assert_eq!(repo.current_branch().unwrap(), "feature/test");
619    }
620
621    #[test]
622    fn test_is_clean() {
623        let (temp, repo) = init_test_repo();
624
625        assert!(repo.is_clean().unwrap());
626
627        // Create and commit a tracked file
628        fs::write(temp.path().join("test.txt"), "initial").unwrap();
629        {
630            let mut index = repo.inner.index().unwrap();
631            index.add_path(std::path::Path::new("test.txt")).unwrap();
632            index.write().unwrap();
633            let tree_id = index.write_tree().unwrap();
634            let tree = repo.inner.find_tree(tree_id).unwrap();
635            let parent = repo.inner.head().unwrap().peel_to_commit().unwrap();
636            let sig = git2::Signature::now("Test", "test@example.com").unwrap();
637            repo.inner
638                .commit(Some("HEAD"), &sig, &sig, "Add test file", &tree, &[&parent])
639                .unwrap();
640        }
641
642        // Should still be clean after commit
643        assert!(repo.is_clean().unwrap());
644
645        // Modify tracked file
646        fs::write(temp.path().join("test.txt"), "modified").unwrap();
647        assert!(!repo.is_clean().unwrap());
648    }
649
650    #[test]
651    fn test_list_branches() {
652        let (_temp, repo) = init_test_repo();
653
654        repo.create_branch("feature/a").unwrap();
655        repo.create_branch("feature/b").unwrap();
656
657        let branches = repo.list_branches().unwrap();
658        assert!(branches.len() >= 3); // main/master + 2 features
659        assert!(branches.iter().any(|b| b == "feature/a"));
660        assert!(branches.iter().any(|b| b == "feature/b"));
661    }
662}