use git2::{FetchOptions, Repository};
use log::debug;
use crate::{
error::{PrError, Result},
get_remote_callbacks,
};
/// A pull request reference extracted from user input.
///
/// Produced by `parse_pr_reference` and the private parsers it delegates to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PullRequest {
    /// The pull request number.
    pub number: u32,
    /// Optional remote name associated with the reference.
    ///
    /// Every parser in this file currently sets this to `None`.
    pub remote: Option<String>,
}
/// Pull request details as reported by the GitHub CLI.
///
/// Populated by `fetch_pr_metadata` from the JSON printed by `gh pr view`.
#[derive(Debug, Clone)]
pub struct PrMetadata {
    /// Pull request number (JSON field `number`).
    pub number: u32,
    /// Pull request title (JSON field `title`).
    pub title: String,
    /// Login of the author (JSON field `author.login`).
    pub author: String,
    /// Name of the branch the pull request comes from (JSON field `headRefName`).
    pub head_ref: String,
    /// Name of the branch the pull request targets (JSON field `baseRefName`).
    pub base_ref: String,
    /// Whether the pull request is cross-repository (JSON field `isCrossRepository`).
    pub is_fork: bool,
    /// Login of the head repository's owner; only set when `is_fork` is true.
    pub fork_owner: Option<String>,
    /// URL of the head repository; only set when `is_fork` is true and the URL was reported.
    pub fork_url: Option<String>,
}
/// Interprets `input` as a pull request reference.
///
/// Recognised forms:
/// - `#<n>`, `pr#<n>` and `pr-<n>` shorthands,
/// - anything containing both `github.com` and `/pull/` (handled by `parse_github_url`),
/// - anything containing `/pull/` and ending in `/head` (handled by `parse_remote_ref`).
///
/// Returns `Ok(None)` when `input` matches none of these forms.
///
/// # Errors
///
/// Returns an error when `input` has one of the recognised shapes but its
/// number part cannot be parsed (for example `#abc`).
pub fn parse_pr_reference(input: &str) -> Result<Option<PullRequest>> {
    // The shorthand prefixes are mutually exclusive, so at most one can match.
    let shorthand_digits = ["#", "pr#", "pr-"]
        .iter()
        .find_map(|prefix| input.strip_prefix(*prefix));

    if let Some(digits) = shorthand_digits {
        let number = parse_number(digits, input)?;
        return Ok(Some(PullRequest {
            number,
            remote: None,
        }));
    }

    let has_pull_segment = input.contains("/pull/");
    if has_pull_segment && input.contains("github.com") {
        return parse_github_url(input);
    }
    if has_pull_segment && input.ends_with("/head") {
        return parse_remote_ref(input);
    }

    Ok(None)
}
/// Parses `num_str` as a `u32` pull request number.
///
/// # Errors
///
/// Returns `PrError::InvalidReference` carrying `original_input` (the full
/// text the user supplied, not just the number part) when parsing fails.
fn parse_number(num_str: &str, original_input: &str) -> Result<u32> {
    match num_str.parse::<u32>() {
        Ok(number) => Ok(number),
        Err(_) => Err(PrError::InvalidReference {
            input: original_input.to_string(),
        }
        .into()),
    }
}
/// Extracts the pull request number from a GitHub pull request URL.
///
/// The number is taken from the path segment that immediately follows the
/// first `pull` segment. A query string or fragment attached to that segment
/// is ignored, so URLs copied from the browser such as
/// `https://github.com/owner/repo/pull/12#issuecomment-1` or
/// `https://github.com/owner/repo/pull/12?diff=split` are accepted.
///
/// # Errors
///
/// Returns `PrError::InvalidReference` carrying `url` when there is no
/// segment after a `pull` segment, or when that segment is not a valid `u32`.
fn parse_github_url(url: &str) -> Result<Option<PullRequest>> {
    let mut segments = url.split('/');

    // `any` consumes the iterator up to and including the first `pull`
    // segment, so the following `next()` yields the candidate number segment.
    if segments.any(|segment| segment == "pull") {
        if let Some(raw) = segments.next() {
            // Drop a trailing `?query` / `#fragment` glued onto the number.
            let num_str = raw
                .split(|c: char| c == '?' || c == '#')
                .next()
                .unwrap_or(raw);
            let number = parse_number(num_str, url)?;
            return Ok(Some(PullRequest {
                number,
                remote: None,
            }));
        }
    }

    Err(PrError::InvalidReference {
        input: url.to_string(),
    }
    .into())
}
/// Extracts the pull request number from a ref shaped like
/// `<prefix>/pull/<n>/head` (for example `origin/pull/111/head`).
///
/// At least one `/`-separated segment must precede `pull`.
///
/// # Errors
///
/// Returns `PrError::InvalidReference` carrying `ref_str` when the last three
/// segments are not `pull`, `<n>`, `head`, when nothing precedes them, or when
/// `<n>` is not a valid `u32`.
fn parse_remote_ref(ref_str: &str) -> Result<Option<PullRequest>> {
    // Walk the segments from the end: head, number, pull, then any prefix.
    let mut from_end = ref_str.rsplit('/');
    let tail = (
        from_end.next(),
        from_end.next(),
        from_end.next(),
        from_end.next(),
    );

    match tail {
        (Some("head"), Some(num_str), Some("pull"), Some(_)) => {
            let number = parse_number(num_str, ref_str)?;
            Ok(Some(PullRequest {
                number,
                remote: None,
            }))
        }
        _ => Err(PrError::InvalidReference {
            input: ref_str.to_string(),
        }
        .into()),
    }
}
/// Checks that the `gh` executable can be launched.
///
/// Runs `gh --version` and discards its output; the exit status is not
/// inspected, only whether the process could be spawned.
///
/// # Errors
///
/// Returns `PrError::GhNotInstalled` when the command cannot be run.
pub fn check_gh_available() -> Result<()> {
    let launched = std::process::Command::new("gh").arg("--version").output();
    match launched {
        Ok(_) => Ok(()),
        Err(_) => Err(PrError::GhNotInstalled.into()),
    }
}
/// Fetches metadata for pull request `pr_number` through the GitHub CLI.
///
/// Runs `gh pr view <pr_number> --json ...` and converts the JSON printed on
/// stdout into a [`PrMetadata`]. `fork_owner` and `fork_url` are only filled in
/// when `isCrossRepository` is `true`; a missing or non-boolean
/// `isCrossRepository` is treated as `false`.
///
/// # Errors
///
/// - Whatever `check_gh_available` returns when `gh` cannot be launched.
/// - `PrError::GhFetchFailed` when the command cannot be run, or exits
///   unsuccessfully (the message is the command's stderr).
/// - `PrError::GhJsonParseFailed` when stdout is not valid JSON, a required
///   field is missing or has the wrong type, or `number` does not fit in `u32`.
/// - `PrError::MissingForkOwner` when the pull request is cross-repository but
///   `headRepository.owner.login` is absent.
pub fn fetch_pr_metadata(pr_number: u32) -> Result<PrMetadata> {
    check_gh_available()?;

    let output = std::process::Command::new("gh")
        .args([
            "pr",
            "view",
            &pr_number.to_string(),
            "--json",
            "number,title,author,headRefName,baseRefName,isCrossRepository,headRepository",
        ])
        .output()
        .map_err(|e| PrError::GhFetchFailed {
            message: e.to_string(),
        })?;

    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        return Err(PrError::GhFetchFailed {
            message: stderr.to_string(),
        }
        .into());
    }

    let json_str = String::from_utf8_lossy(&output.stdout);
    let json: serde_json::Value =
        serde_json::from_str(&json_str).map_err(|e| PrError::GhJsonParseFailed {
            message: e.to_string(),
        })?;

    let raw_number = json["number"]
        .as_u64()
        .ok_or_else(|| PrError::GhJsonParseFailed {
            message: "Missing 'number' field".to_string(),
        })?;
    // Checked narrowing: an `as u32` cast would silently truncate large values.
    // The trait is named by full path so this compiles regardless of edition.
    let number = <u32 as std::convert::TryFrom<u64>>::try_from(raw_number).map_err(|_| {
        PrError::GhJsonParseFailed {
            message: format!("PR number {} is out of range", raw_number),
        }
    })?;

    let title = required_json_string(&json["title"], "title")?;
    let author = required_json_string(&json["author"]["login"], "author.login")?;
    let head_ref = required_json_string(&json["headRefName"], "headRefName")?;
    let base_ref = required_json_string(&json["baseRefName"], "baseRefName")?;

    let is_fork = json["isCrossRepository"].as_bool().unwrap_or(false);

    // NOTE(review): confirm that `gh` really nests `owner.login` and `url`
    // under `headRepository`; `gh` also offers a separate
    // `headRepositoryOwner` JSON field that is not requested here.
    let (fork_owner, fork_url) = if is_fork {
        let head_repo = &json["headRepository"];
        let owner = head_repo["owner"]["login"]
            .as_str()
            .ok_or(PrError::MissingForkOwner)?
            .to_string();
        // The URL is optional at this point; callers decide what to do without it.
        let url = head_repo["url"].as_str().map(|s| s.to_string());
        (Some(owner), url)
    } else {
        (None, None)
    };

    Ok(PrMetadata {
        number,
        title,
        author,
        head_ref,
        base_ref,
        is_fork,
        fork_owner,
        fork_url,
    })
}

/// Returns `value` as an owned string, or `PrError::GhJsonParseFailed` naming
/// `field` when `value` is absent or not a JSON string.
fn required_json_string(value: &serde_json::Value, field: &str) -> Result<String> {
    match value.as_str() {
        Some(text) => Ok(text.to_string()),
        None => Err(PrError::GhJsonParseFailed {
            message: format!("Missing '{}' field", field),
        }
        .into()),
    }
}
/// Turns arbitrary text into a lowercase fragment usable in a branch name.
///
/// ASCII letters are lowercased; ASCII digits and `_` are kept; every other
/// character (including non-ASCII ones) becomes `-`. Runs of `-` collapse to a
/// single `-`, and leading/trailing `-` and `_` are removed. The result may be
/// empty.
fn sanitize_for_branch_name(s: &str) -> String {
    let mut collapsed = String::with_capacity(s.len());

    for ch in s.chars() {
        let mapped = if ch.is_ascii_alphanumeric() || ch == '_' {
            ch.to_ascii_lowercase()
        } else {
            '-'
        };

        // Skip a dash that would directly follow another dash.
        if mapped == '-' && collapsed.ends_with('-') {
            continue;
        }
        collapsed.push(mapped);
    }

    collapsed
        .trim_matches(|c| c == '-' || c == '_')
        .to_string()
}
/// Expands the placeholders in `format` using `metadata`.
///
/// Supported placeholders are `{number}`, `{title}`, `{author}` and
/// `{branch}` (the head ref). All values except the number are passed through
/// `sanitize_for_branch_name` first. Text outside placeholders is left as is.
pub fn format_pr_name_with_metadata(format: &str, metadata: &PrMetadata) -> String {
    let substitutions = [
        ("{number}", metadata.number.to_string()),
        ("{title}", sanitize_for_branch_name(&metadata.title)),
        ("{author}", sanitize_for_branch_name(&metadata.author)),
        ("{branch}", sanitize_for_branch_name(&metadata.head_ref)),
    ];

    // Apply the replacements one after another, in the order listed above.
    let mut name = format.to_string();
    for (placeholder, value) in &substitutions {
        name = name.replace(*placeholder, value.as_str());
    }
    name
}
/// Returns `true` when `parse_pr_reference` successfully parses `input` into a
/// pull request; parse errors and non-PR inputs both yield `false`.
pub fn is_pr_reference(input: &str) -> bool {
    matches!(parse_pr_reference(input), Ok(Some(_)))
}
/// Picks the remote to use for pull requests in `repo`.
///
/// Prefers a remote named `upstream`, then `origin`; if neither exists, falls
/// back to the remote at index 0 of the repository's remote list.
///
/// # Errors
///
/// Propagates the error from listing remotes, and returns
/// `PrError::NoRemoteConfigured` when no fallback remote name is available.
pub fn detect_pr_remote(repo: &Repository) -> Result<String> {
    let remotes = repo.remotes()?;

    let preferred_order = ["upstream", "origin"];
    let preferred = preferred_order.iter().find(|candidate| {
        remotes
            .iter()
            .flatten()
            .any(|existing| existing == **candidate)
    });

    if let Some(name) = preferred {
        debug!("Using remote: {}", name);
        return Ok(name.to_string());
    }

    match remotes.get(0) {
        Some(first_remote) => Ok(first_remote.to_string()),
        None => Err(PrError::NoRemoteConfigured.into()),
    }
}
/// Returns the name of the remote to fetch the pull request from, adding a
/// remote for the fork when necessary.
///
/// For a non-fork pull request this simply defers to `detect_pr_remote`.
/// For a fork, the remote is named `pr-<number>-fork`; if a remote with that
/// name already exists it is reused as is, otherwise it is created pointing
/// at `metadata.fork_url`.
///
/// # Errors
///
/// - `PrError::MissingForkOwner` when the fork's owner or URL is absent.
/// - `PrError::FetchFailed` when the remote cannot be added.
pub fn setup_fork_remote(repo: &Repository, metadata: &PrMetadata) -> Result<String> {
    if !metadata.is_fork {
        return detect_pr_remote(repo);
    }

    // The owner is not used below, but its absence is still treated as an error.
    if metadata.fork_owner.is_none() {
        return Err(PrError::MissingForkOwner.into());
    }
    let fork_url = match metadata.fork_url.as_deref() {
        Some(url) => url,
        None => return Err(PrError::MissingForkOwner.into()),
    };

    let fork_remote_name = format!("pr-{}-fork", metadata.number);
    if repo.find_remote(&fork_remote_name).is_ok() {
        debug!("Fork remote {} already exists", fork_remote_name);
        return Ok(fork_remote_name);
    }

    debug!("Adding fork remote: {} -> {}", fork_remote_name, fork_url);
    match repo.remote(&fork_remote_name, fork_url) {
        Ok(_) => Ok(fork_remote_name),
        Err(e) => Err(PrError::FetchFailed {
            remote: fork_remote_name.clone(),
            message: format!("Failed to add fork remote: {}", e),
        }
        .into()),
    }
}
/// Ensures `refs/remotes/<remote_name>/<branch>` exists in `repo`.
///
/// If the reference is already present nothing is fetched (an existing
/// reference is not refreshed). Otherwise `refs/heads/<branch>` is fetched
/// from `remote_name` into that reference with a forced (`+`) refspec.
///
/// # Errors
///
/// Propagates errors from `get_remote_callbacks` and from looking up the
/// remote, and returns `PrError::FetchFailed` when the fetch itself fails.
pub fn fetch_branch(repo: &Repository, remote_name: &str, branch: &str) -> Result<()> {
    let tracking_ref = format!("refs/remotes/{}/{}", remote_name, branch);
    if repo.find_reference(&tracking_ref).is_ok() {
        debug!("Branch ref {} already exists", tracking_ref);
        return Ok(());
    }

    debug!("Fetching branch {} from remote {}", branch, remote_name);
    let refspec = format!("+refs/heads/{}:{}", branch, tracking_ref);

    // The remote URL (if any) is handed to the callback factory.
    let remote_url = match repo.find_remote(remote_name) {
        Ok(remote) => remote.url().map(str::to_string),
        Err(_) => None,
    };
    let callbacks = get_remote_callbacks(remote_url.as_deref())?;
    let mut fetch_options = FetchOptions::new();
    fetch_options.remote_callbacks(callbacks);

    let mut remote = repo.find_remote(remote_name)?;
    let outcome = remote.fetch(
        &[refspec.as_str()],
        Some(&mut fetch_options),
        Some("Fetching PR branch"),
    );
    if let Err(e) = outcome {
        return Err(PrError::FetchFailed {
            remote: remote_name.to_string(),
            message: e.message().to_string(),
        }
        .into());
    }

    debug!("Successfully fetched branch {}", branch);
    Ok(())
}
/// Replaces every `{number}` placeholder in `format` with `pr_number`.
///
/// Other text, including any other `{...}` placeholders, is left untouched.
pub fn format_pr_name(format: &str, pr_number: u32) -> String {
    let number_text = pr_number.to_string();
    format.replace("{number}", number_text.as_str())
}
/// Gathers everything needed to create a worktree for pull request `pr_number`.
///
/// Fetches the pull request's metadata, picks (or adds) the remote to fetch
/// from, makes sure the head branch is available locally, and formats the
/// worktree name from `pr_format`.
///
/// Returns `(worktree_name, remote_ref, base_ref)`, where `remote_ref` is
/// `<remote>/<head branch>` and `base_ref` is the pull request's base branch.
///
/// # Errors
///
/// Propagates any error from `fetch_pr_metadata`, `setup_fork_remote`,
/// `detect_pr_remote` or `fetch_branch`.
pub fn prepare_pr_worktree(
    repo: &Repository,
    pr_number: u32,
    pr_format: &str,
) -> Result<(String, String, String)> {
    debug!("Preparing PR worktree for PR #{}", pr_number);

    let pr = fetch_pr_metadata(pr_number)?;
    debug!(
        "Fetched metadata: title='{}', author='{}', is_fork={}",
        pr.title, pr.author, pr.is_fork
    );

    let remote_lookup = if pr.is_fork {
        setup_fork_remote(repo, &pr)
    } else {
        detect_pr_remote(repo)
    };
    let remote_name = remote_lookup?;

    fetch_branch(repo, &remote_name, &pr.head_ref)?;

    let worktree_name = format_pr_name_with_metadata(pr_format, &pr);
    debug!("Worktree name: {}", worktree_name);

    let remote_ref = format!("{}/{}", remote_name, pr.head_ref);
    debug!("Remote ref: {}", remote_ref);

    Ok((worktree_name, remote_ref, pr.base_ref))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input`, panicking unless it is a valid pull request reference.
    fn expect_pr(input: &str) -> PullRequest {
        parse_pr_reference(input).unwrap().unwrap()
    }

    /// The reference every parser is expected to produce for number `number`.
    fn bare_pr(number: u32) -> PullRequest {
        PullRequest {
            number,
            remote: None,
        }
    }

    #[test]
    fn test_parse_hash_number() {
        assert_eq!(expect_pr("#123"), bare_pr(123));
    }

    #[test]
    fn test_parse_pr_hash_number() {
        assert_eq!(expect_pr("pr#456"), bare_pr(456));
    }

    #[test]
    fn test_parse_pr_dash_number() {
        assert_eq!(expect_pr("pr-789"), bare_pr(789));
    }

    #[test]
    fn test_parse_github_url() {
        assert_eq!(
            expect_pr("https://github.com/owner/repo/pull/999"),
            bare_pr(999)
        );
    }

    #[test]
    fn test_parse_remote_ref() {
        assert_eq!(expect_pr("origin/pull/111/head"), bare_pr(111));
    }

    #[test]
    fn test_parse_regular_branch_name() {
        assert!(parse_pr_reference("my-feature-branch").unwrap().is_none());
    }

    #[test]
    fn test_parse_invalid_number() {
        assert!(parse_pr_reference("#abc").is_err());
    }

    #[test]
    fn test_is_pr_reference_true() {
        let references = [
            "#123",
            "pr#456",
            "pr-789",
            "https://github.com/owner/repo/pull/999",
        ];
        for input in references.iter() {
            assert!(is_pr_reference(input));
        }
    }

    #[test]
    fn test_is_pr_reference_false() {
        let branches = ["my-branch", "feature"];
        for input in branches.iter() {
            assert!(!is_pr_reference(input));
        }
    }

    #[test]
    fn test_format_pr_name() {
        let cases = [
            ("pr-{number}", 123, "pr-123"),
            ("review-{number}", 456, "review-456"),
            ("{number}-test", 789, "789-test"),
        ];
        for (format, number, expected) in cases.iter() {
            assert_eq!(format_pr_name(format, *number), *expected);
        }
    }

    #[test]
    fn test_sanitize_branch_name() {
        let cases = [
            ("Fix Bug #123", "fix-bug-123"),
            ("Add Feature (v2)", "add-feature-v2"),
            ("john-smith", "john-smith"),
            ("Fix: Authentication Issue", "fix-authentication-issue"),
            ("Test@#$%", "test"),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(sanitize_for_branch_name(input), *expected);
        }
    }

    #[test]
    fn test_format_with_metadata() {
        let metadata = PrMetadata {
            number: 123,
            title: "Fix Authentication Bug".to_string(),
            author: "john-smith".to_string(),
            head_ref: "feature/fix-auth".to_string(),
            base_ref: "main".to_string(),
            is_fork: false,
            fork_owner: None,
            fork_url: None,
        };

        let cases = [
            ("pr-{number}", "pr-123"),
            ("{number}-{title}", "123-fix-authentication-bug"),
            ("{author}/pr-{number}", "john-smith/pr-123"),
            ("{branch}-{number}", "feature-fix-auth-123"),
        ];
        for (format, expected) in cases.iter() {
            assert_eq!(format_pr_name_with_metadata(format, &metadata), *expected);
        }
    }

    // The two tests below need a real `gh` installation, so they are ignored
    // by default.
    #[test]
    #[ignore]
    fn test_gh_cli_available() {
        check_gh_available().expect("gh CLI should be installed");
    }

    #[test]
    #[ignore]
    fn test_fetch_real_pr_metadata() {
        let metadata = fetch_pr_metadata(1).expect("Failed to fetch PR metadata");
        assert_eq!(metadata.number, 1);
        assert!(!metadata.title.is_empty());
        assert!(!metadata.author.is_empty());
    }
}