use crate::errors::Result;
use crate::git::GitRepository;
use serde::{Deserialize, Serialize};
/// Tracking information for the upstream of a local branch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpstreamInfo {
    /// Remote name: the part of the upstream name before the first `/`
    /// (`"origin"` when the upstream name contains no `/`).
    pub remote: String,
    /// Branch name on the remote: the part of the upstream name after the first `/`.
    pub branch: String,
    /// Upstream name exactly as reported by the repository (for example `origin/main`).
    pub full_name: String,
    /// "Ahead" count reported by the repository for the local branch versus its upstream;
    /// `0` when the count could not be determined.
    pub ahead: usize,
    /// "Behind" count reported by the repository for the local branch versus its upstream;
    /// `0` when the count could not be determined.
    pub behind: usize,
}
/// Snapshot of a single branch as produced by `BranchManager::get_branch_info`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BranchInfo {
    /// Branch name as listed by the repository.
    pub name: String,
    /// Commit hash the repository reports for this branch.
    pub commit_hash: String,
    /// `true` when this branch is the repository's current branch.
    pub is_current: bool,
    /// Upstream tracking details, or `None` when the branch has no upstream configured.
    pub upstream: Option<UpstreamInfo>,
}
/// Branch-level operations (listing, naming, creation, upstream queries) layered on top of
/// an owned [`GitRepository`].
pub struct BranchManager {
    /// Repository handle every operation is delegated to.
    git_repo: GitRepository,
}
impl BranchManager {
    /// Maximum number of words from a message that are kept in a generated branch name.
    const MAX_SLUG_WORDS: usize = 5;

    /// Prefix used when a generated name would otherwise not start with a letter.
    const FALLBACK_PREFIX: &'static str = "feature";

    /// Creates a manager that operates on `git_repo`.
    pub fn new(git_repo: GitRepository) -> Self {
        Self { git_repo }
    }

    /// Collects a [`BranchInfo`] for every branch listed by the repository.
    ///
    /// # Errors
    ///
    /// Fails if listing the branches, resolving a branch's commit hash, or looking up a
    /// branch's upstream fails. Processing stops at the first error.
    pub fn get_branch_info(&self) -> Result<Vec<BranchInfo>> {
        let branches = self.git_repo.list_branches()?;
        // Best effort: if the current branch cannot be determined the error is discarded
        // and no entry is flagged as current.
        let current_branch = self.git_repo.get_current_branch().ok();

        branches
            .into_iter()
            .map(|branch_name| -> Result<BranchInfo> {
                let commit_hash = self.get_branch_commit_hash(&branch_name)?;
                let is_current = current_branch.as_ref() == Some(&branch_name);
                let upstream = self.get_upstream_info(&branch_name)?;
                Ok(BranchInfo {
                    name: branch_name,
                    commit_hash,
                    is_current,
                    upstream,
                })
            })
            .collect()
    }

    /// Returns the commit hash the repository reports for `branch_name`.
    fn get_branch_commit_hash(&self, branch_name: &str) -> Result<String> {
        self.git_repo.get_branch_commit_hash(branch_name)
    }

    /// Builds the [`UpstreamInfo`] for `branch_name`, or `Ok(None)` when the repository
    /// reports no upstream for it.
    fn get_upstream_info(&self, branch_name: &str) -> Result<Option<UpstreamInfo>> {
        let Some(upstream) = self.git_repo.get_upstream_branch(branch_name)? else {
            return Ok(None);
        };

        let (remote, remote_branch) = self.parse_upstream_name(&upstream)?;
        let (ahead, behind) = self.calculate_ahead_behind_counts(branch_name, &upstream)?;

        Ok(Some(UpstreamInfo {
            remote,
            branch: remote_branch,
            full_name: upstream,
            ahead,
            behind,
        }))
    }

    /// Splits an upstream name such as `origin/feature-x` into `(remote, branch)` at the
    /// first `/`.
    ///
    /// A name without any `/` is treated as a branch on the remote `origin`. This function
    /// currently never returns `Err`.
    fn parse_upstream_name(&self, upstream: &str) -> Result<(String, String)> {
        let (remote, branch) = upstream.split_once('/').unwrap_or(("origin", upstream));
        Ok((remote.to_string(), branch.to_string()))
    }

    /// Returns the `(ahead, behind)` counts the repository reports for `local_branch`
    /// against `upstream_branch`.
    ///
    /// This is deliberately best effort: if the repository cannot compute the counts the
    /// error is dropped and `(0, 0)` is returned, so this function never returns `Err`.
    fn calculate_ahead_behind_counts(
        &self,
        local_branch: &str,
        upstream_branch: &str,
    ) -> Result<(usize, usize)> {
        Ok(self
            .git_repo
            .get_ahead_behind_counts(local_branch, upstream_branch)
            .unwrap_or((0, 0)))
    }

    /// Lowercases `message`, replaces every character outside `a-z` / `0-9` with `-`, and
    /// joins the first [`Self::MAX_SLUG_WORDS`] non-empty words with single hyphens.
    ///
    /// The result is empty when the message contains no ASCII letter or digit.
    fn slugify(message: &str) -> String {
        let normalized: String = message
            .to_lowercase()
            .chars()
            .map(|c| {
                if c.is_ascii_lowercase() || c.is_ascii_digit() {
                    c
                } else {
                    '-'
                }
            })
            .collect();

        normalized
            .split('-')
            .filter(|word| !word.is_empty())
            .take(Self::MAX_SLUG_WORDS)
            .collect::<Vec<_>>()
            .join("-")
    }

    /// Derives a branch name from a free-form `message` that does not clash with any
    /// branch the repository currently reports as existing.
    ///
    /// The message is slugified (see `slugify`). A slug that starts with a digit is
    /// prefixed with `feature-`, and a message that yields no slug at all becomes
    /// `feature`. If the resulting name is taken, `-1`, `-2`, ... is appended until a
    /// free name is found.
    pub fn generate_branch_name(&self, message: &str) -> String {
        let slug = Self::slugify(message);

        // The prefix must be applied *before* the uniqueness probe: the name that gets
        // checked against existing branches has to be the name that is returned.
        let base_name = if slug.is_empty() {
            Self::FALLBACK_PREFIX.to_string()
        } else if slug.starts_with(|c: char| c.is_ascii_lowercase()) {
            slug
        } else {
            // Slug characters are limited to a-z, 0-9 and '-', and a slug never starts
            // with '-', so this branch means "starts with a digit".
            format!("{}-{slug}", Self::FALLBACK_PREFIX)
        };

        let mut candidate = base_name.clone();
        let mut counter = 1;
        while self.git_repo.branch_exists(&candidate) {
            candidate = format!("{base_name}-{counter}");
            counter += 1;
        }
        candidate
    }

    /// Generates a branch name from `message`, creates that branch, and returns its name.
    ///
    /// `target` is forwarded unchanged to the repository's `create_branch`.
    ///
    /// # Errors
    ///
    /// Fails if the repository cannot create the branch.
    pub fn create_branch_from_message(
        &self,
        message: &str,
        target: Option<&str>,
    ) -> Result<String> {
        let branch_name = self.generate_branch_name(message);
        self.git_repo.create_branch(&branch_name, target)?;
        Ok(branch_name)
    }

    /// Configures `remote`/`remote_branch` as the upstream of `branch_name` by delegating
    /// to the repository.
    pub fn set_upstream(&self, branch_name: &str, remote: &str, remote_branch: &str) -> Result<()> {
        self.git_repo
            .set_upstream(branch_name, remote, remote_branch)
    }

    /// Returns the upstream details of `branch_name`, or `Ok(None)` when it has no upstream.
    pub fn get_branch_upstream(&self, branch_name: &str) -> Result<Option<UpstreamInfo>> {
        self.get_upstream_info(branch_name)
    }

    /// Reports whether the repository has an upstream recorded for `branch_name`.
    ///
    /// Only the upstream name is looked up; ahead/behind counts are not computed.
    ///
    /// # Errors
    ///
    /// Fails if the upstream lookup itself fails.
    pub fn has_upstream(&self, branch_name: &str) -> Result<bool> {
        Ok(self.git_repo.get_upstream_branch(branch_name)?.is_some())
    }

    /// Borrows the underlying repository handle.
    pub fn git_repo(&self) -> &GitRepository {
        &self.git_repo
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::git::repository::*;
    use git2::{Repository, Signature};
    use tempfile::TempDir;

    /// Initialises a repository with one empty commit in a fresh temporary directory and
    /// wraps it in a `BranchManager`.
    ///
    /// The `TempDir` is handed back so the directory outlives the manager within each test.
    fn create_test_branch_manager() -> (TempDir, BranchManager) {
        let workdir = TempDir::new().unwrap();
        let raw_repo = Repository::init(workdir.path()).unwrap();
        let author = Signature::now("Test User", "test@example.com").unwrap();

        // Writing the untouched index yields the empty tree for the root commit.
        let empty_tree_id = {
            let mut index = raw_repo.index().unwrap();
            index.write_tree().unwrap()
        };
        let empty_tree = raw_repo.find_tree(empty_tree_id).unwrap();
        raw_repo
            .commit(
                Some("HEAD"),
                &author,
                &author,
                "Initial commit",
                &empty_tree,
                &[],
            )
            .unwrap();

        let manager = BranchManager::new(GitRepository::open(workdir.path()).unwrap());
        (workdir, manager)
    }

    #[test]
    fn test_branch_name_generation() {
        let (_workdir, manager) = create_test_branch_manager();

        let cases = [
            ("Add user authentication", "add-user-authentication"),
            ("Fix bug in payment system!!!", "fix-bug-in-payment-system"),
            ("123 numeric start", "feature-123-numeric-start"),
        ];
        for (message, expected) in cases {
            assert_eq!(manager.generate_branch_name(message), expected);
        }
    }

    #[test]
    fn test_branch_creation() {
        let (_workdir, manager) = create_test_branch_manager();

        let created = manager
            .create_branch_from_message("Add login feature", None)
            .unwrap();

        assert_eq!(created, "add-login-feature");
        assert!(manager.git_repo().branch_exists(&created));
    }

    #[test]
    fn test_branch_info() {
        let (_workdir, manager) = create_test_branch_manager();
        manager
            .create_branch_from_message("Test feature", None)
            .unwrap();

        let infos = manager.get_branch_info().unwrap();

        assert!(!infos.is_empty());
        assert!(infos.iter().any(|info| info.is_current));
        // The initial branch plus the one created above.
        assert!(infos.len() >= 2);
        // A freshly initialised local repository has no remotes, hence no upstreams.
        assert!(infos.iter().all(|info| info.upstream.is_none()));
    }

    #[test]
    fn test_upstream_parsing() {
        let (_workdir, manager) = create_test_branch_manager();

        // (input, expected remote, expected branch)
        let cases = [
            ("origin/feature-auth", "origin", "feature-auth"),
            ("upstream/main", "upstream", "main"),
            ("feature-auth", "origin", "feature-auth"),
        ];
        for (input, expected_remote, expected_branch) in cases {
            let (remote, branch) = manager.parse_upstream_name(input).unwrap();
            assert_eq!(remote, expected_remote);
            assert_eq!(branch, expected_branch);
        }
    }
}