use anyhow::{anyhow, Context, Result};
use git2::{
build::CheckoutBuilder, BranchType, Oid, Repository, Signature, StashFlags, Status,
StatusOptions,
};
pub struct GitRepo {
repo: Repository,
}
impl GitRepo {
pub fn open() -> Result<Self> {
let repo = Repository::discover(".").context("Not a git repository")?;
Ok(Self { repo })
}
pub fn git_dir(&self) -> &std::path::Path {
self.repo.path()
}
pub fn workdir(&self) -> Result<&std::path::Path> {
self.repo
.workdir()
.ok_or_else(|| anyhow!("Bare repository has no working directory"))
}
pub fn get_current_branch(&self) -> Result<String> {
let head = self.repo.head().context("Failed to get HEAD")?;
if head.is_branch() {
let name = head
.shorthand()
.ok_or_else(|| anyhow!("Invalid branch name"))?;
Ok(name.to_string())
} else {
let oid = head.target().ok_or_else(|| anyhow!("No target for HEAD"))?;
Ok(format!("(detached at {})", &oid.to_string()[..7]))
}
}
pub fn has_uncommitted_changes(&self) -> Result<bool> {
let mut opts = StatusOptions::new();
opts.include_untracked(true)
.recurse_untracked_dirs(true)
.exclude_submodules(true);
let statuses = self.repo.statuses(Some(&mut opts))?;
for entry in statuses.iter() {
let status = entry.status();
if status.intersects(
Status::INDEX_NEW
| Status::INDEX_MODIFIED
| Status::INDEX_DELETED
| Status::INDEX_RENAMED
| Status::INDEX_TYPECHANGE
| Status::WT_NEW
| Status::WT_MODIFIED
| Status::WT_DELETED
| Status::WT_RENAMED
| Status::WT_TYPECHANGE,
) {
return Ok(true);
}
}
Ok(false)
}
pub fn get_changes_summary(&self) -> Result<String> {
let mut opts = StatusOptions::new();
opts.include_untracked(true)
.recurse_untracked_dirs(true)
.exclude_submodules(true);
let statuses = self.repo.statuses(Some(&mut opts))?;
let mut staged = 0;
let mut modified = 0;
let mut untracked = 0;
for entry in statuses.iter() {
let status = entry.status();
if status.intersects(
Status::INDEX_NEW
| Status::INDEX_MODIFIED
| Status::INDEX_DELETED
| Status::INDEX_RENAMED
| Status::INDEX_TYPECHANGE,
) {
staged += 1;
}
if status.intersects(
Status::WT_MODIFIED
| Status::WT_DELETED
| Status::WT_RENAMED
| Status::WT_TYPECHANGE,
) {
modified += 1;
}
if status.contains(Status::WT_NEW) {
untracked += 1;
}
}
let mut parts = Vec::new();
if staged > 0 {
parts.push(format!("{} staged", staged));
}
if modified > 0 {
parts.push(format!("{} modified", modified));
}
if untracked > 0 {
parts.push(format!("{} untracked", untracked));
}
Ok(parts.join(", "))
}
pub fn stash_save(&mut self, message: &str) -> Result<Oid> {
let signature = self.get_signature()?;
let oid = self
.repo
.stash_save(&signature, message, Some(StashFlags::INCLUDE_UNTRACKED))?;
Ok(oid)
}
pub fn stash_apply(&mut self, target_oid: Oid) -> Result<()> {
let index = self.find_stash_index(target_oid)?;
self.repo.stash_apply(index, None)?;
Ok(())
}
pub fn stash_drop(&mut self, target_oid: Oid) -> Result<()> {
let index = self.find_stash_index(target_oid)?;
self.repo.stash_drop(index)?;
Ok(())
}
fn find_stash_index(&mut self, target_oid: Oid) -> Result<usize> {
let mut found_index: Option<usize> = None;
self.repo.stash_foreach(|index, _message, oid| {
if *oid == target_oid {
found_index = Some(index);
false } else {
true }
})?;
found_index.ok_or_else(|| anyhow!("Stash not found with OID: {}", target_oid))
}
pub fn list_stashes(&mut self) -> Result<Vec<StashInfo>> {
let mut stashes = Vec::new();
self.repo.stash_foreach(|index, message, oid| {
stashes.push(StashInfo {
index,
message: message.to_string(),
oid: *oid,
});
true
})?;
Ok(stashes)
}
pub fn switch_branch(&self, branch_name: &str) -> Result<()> {
let branch = self
.repo
.find_branch(branch_name, BranchType::Local)
.with_context(|| format!("Branch '{}' not found", branch_name))?;
let reference = branch.get();
let tree = reference.peel_to_tree()?;
let mut checkout_builder = CheckoutBuilder::new();
checkout_builder.safe();
self.repo
.checkout_tree(tree.as_object(), Some(&mut checkout_builder))?;
let refname = reference
.name()
.ok_or_else(|| anyhow!("Invalid reference name"))?;
self.repo.set_head(refname)?;
Ok(())
}
pub fn list_branches(&self) -> Result<Vec<String>> {
let mut branches = Vec::new();
for branch in self.repo.branches(Some(BranchType::Local))? {
let (branch, _) = branch?;
if let Some(name) = branch.name()? {
branches.push(name.to_string());
}
}
Ok(branches)
}
fn get_signature(&self) -> Result<Signature<'static>> {
if let Ok(sig) = self.repo.signature() {
return Ok(Signature::now(
sig.name().unwrap_or("git-switch"),
sig.email().unwrap_or("git-switch@local"),
)?);
}
Ok(Signature::now("git-switch", "git-switch@local")?)
}
pub fn branch_exists(&self, name: &str) -> bool {
self.repo.find_branch(name, BranchType::Local).is_ok()
}
pub fn create_branch(&self, name: &str) -> Result<()> {
let head = self.repo.head()?;
let head_commit = head.peel_to_commit()?;
self.repo.branch(name, &head_commit, false)?;
Ok(())
}
pub fn discard_changes(&self) -> Result<()> {
let head = self.repo.head()?;
let head_commit = head.peel_to_commit()?;
let tree = head_commit.tree()?;
let mut checkout_builder = CheckoutBuilder::new();
checkout_builder.force();
checkout_builder.remove_untracked(true);
self.repo
.checkout_tree(tree.as_object(), Some(&mut checkout_builder))?;
self.repo
.reset(head_commit.as_object(), git2::ResetType::Hard, None)?;
Ok(())
}
pub fn delete_branch(&self, name: &str) -> Result<()> {
let mut branch = self
.repo
.find_branch(name, BranchType::Local)
.with_context(|| format!("Branch '{}' not found", name))?;
branch.delete()?;
Ok(())
}
pub fn fetch(&self, remote_name: &str) -> Result<()> {
let mut remote = self
.repo
.find_remote(remote_name)
.with_context(|| format!("Remote '{}' not found", remote_name))?;
let refspecs: Vec<String> = remote
.fetch_refspecs()?
.iter()
.filter_map(|s| s.map(String::from))
.collect();
let refspec_strs: Vec<&str> = refspecs.iter().map(|s| s.as_str()).collect();
remote.fetch(&refspec_strs, None, None)?;
Ok(())
}
pub fn pull(&mut self) -> Result<()> {
let head = self.repo.head()?;
if !head.is_branch() {
return Err(anyhow!("Cannot pull in detached HEAD state"));
}
let branch_name = head
.shorthand()
.ok_or_else(|| anyhow!("Invalid branch name"))?;
let branch = self.repo.find_branch(branch_name, BranchType::Local)?;
let upstream = branch.upstream().context("No upstream branch configured")?;
let upstream_name = upstream
.name()?
.ok_or_else(|| anyhow!("Invalid upstream branch name"))?;
let remote_name = upstream_name
.split('/')
.next()
.ok_or_else(|| anyhow!("Invalid upstream format"))?;
self.fetch(remote_name)?;
let upstream_ref = upstream.get();
let upstream_commit = upstream_ref.peel_to_commit()?;
let fetch_head = self.repo.find_reference("FETCH_HEAD")?;
let annotated = self.repo.reference_to_annotated_commit(&fetch_head)?;
let analysis = self.repo.merge_analysis(&[&annotated])?;
if analysis.0.is_up_to_date() {
return Ok(());
}
if analysis.0.is_fast_forward() {
let refname = head.name().ok_or_else(|| anyhow!("Invalid HEAD ref"))?;
self.repo
.reference(refname, upstream_commit.id(), true, "pull: fast-forward")?;
let mut checkout = CheckoutBuilder::new();
checkout.force();
self.repo.checkout_head(Some(&mut checkout))?;
} else {
return Err(anyhow!(
"Cannot fast-forward. Please merge or rebase manually."
));
}
Ok(())
}
pub fn get_tracking_remote(&self, branch_name: &str) -> Result<Option<String>> {
let branch = match self.repo.find_branch(branch_name, BranchType::Local) {
Ok(b) => b,
Err(_) => return Ok(None),
};
match branch.upstream() {
Ok(upstream) => {
let name = upstream.name()?.map(String::from);
Ok(name)
}
Err(_) => Ok(None),
}
}
pub fn create_tracking_branch(&self, local_name: &str, remote_ref: &str) -> Result<()> {
let reference = self
.repo
.find_reference(&format!("refs/remotes/{}", remote_ref))
.with_context(|| format!("Remote branch '{}' not found", remote_ref))?;
let commit = reference.peel_to_commit()?;
let mut branch = self.repo.branch(local_name, &commit, false)?;
branch.set_upstream(Some(remote_ref))?;
Ok(())
}
}
#[derive(Debug, Clone)]
pub struct StashInfo {
pub index: usize,
pub message: String,
pub oid: Oid,
}