use crate::{git, github, spinner};
use anyhow::{Context, Result, anyhow};
/// Abstraction over the git/fork operations used by
/// `detect_remote_branch_internal`, so the detection logic can be exercised
/// against a mock instead of a real repository.
trait RemoteDetectionContext {
/// Returns the names of the configured git remotes.
fn list_remotes(&self) -> Result<Vec<String>>;
/// Reports whether the given ref (e.g. `refs/remotes/<remote>/<branch>`) exists.
fn branch_exists(&self, ref_name: &str) -> Result<bool>;
/// Resolves an `owner:branch` fork spec into a remote ref and a base name
/// for the local branch.
fn resolve_fork(&self, spec: &git::ForkBranchSpec) -> Result<ForkBranchResult>;
/// Fetches from the remote with the given name.
fn fetch_remote(&self, remote: &str) -> Result<()>;
}
/// Production `RemoteDetectionContext`: a stateless marker type whose trait
/// implementation forwards to the `git` module and `resolve_fork_branch`.
struct RealRemoteDetectionContext;
/// Forwards every trait method directly to the corresponding real helper;
/// no additional logic lives in this implementation.
impl RemoteDetectionContext for RealRemoteDetectionContext {
// Delegates to `git::list_remotes`.
fn list_remotes(&self) -> Result<Vec<String>> {
git::list_remotes()
}
// Delegates to `git::branch_exists`.
fn branch_exists(&self, ref_name: &str) -> Result<bool> {
git::branch_exists(ref_name)
}
// Delegates to `resolve_fork_branch` defined in this file.
fn resolve_fork(&self, spec: &git::ForkBranchSpec) -> Result<ForkBranchResult> {
resolve_fork_branch(spec)
}
// Delegates to `git::fetch_remote`.
fn fetch_remote(&self, remote: &str) -> Result<()> {
git::fetch_remote(remote)
}
}
/// Builds the local branch name for a branch coming from a fork.
///
/// The result is `<owner>-<branch>`; any slashes inside `branch` are kept
/// as-is.
fn fork_local_branch_name(owner: &str, branch: &str) -> String {
    [owner, branch].join("-")
}
/// Branch names needed to check out a pull request, as produced by
/// `resolve_pr_ref`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PrCheckoutResult {
    /// Name to use for the local branch (caller-supplied or derived from the PR).
    pub local_branch: String,
    /// Remote-qualified branch in `<remote>/<branch>` form.
    pub remote_branch: String,
}
/// Resolves a pull request number into the local and remote branch names
/// needed to check it out.
///
/// PR details are fetched via `github::get_pr_details` (wrapped in a spinner).
/// The PR title, author and head branch are printed to stdout; a warning is
/// written to stderr when the PR state is not `"OPEN"` or when it is a draft.
///
/// When `pr_details.is_fork(..)` reports that the PR comes from a fork, a
/// remote for the head repository owner is ensured through
/// `git::ensure_fork_remote` and used as the remote name; otherwise `origin`
/// is used.
///
/// # Arguments
///
/// * `pr_number` - Number of the pull request to resolve.
/// * `custom_branch_name` - Optional local branch name. When `None`, the name
///   is the PR head branch, prefixed with `<fork owner>-` for fork PRs.
///
/// # Errors
///
/// Returns an error if the PR details cannot be fetched, if the repository
/// owner cannot be determined, or if ensuring the fork remote fails.
pub fn resolve_pr_ref(
    pr_number: u32,
    custom_branch_name: Option<&str>,
) -> Result<PrCheckoutResult> {
    let pr_details = spinner::with_spinner(&format!("Fetching PR #{}", pr_number), || {
        github::get_pr_details(pr_number)
    })
    .with_context(|| format!("Failed to fetch details for PR #{}", pr_number))?;

    println!("PR #{}: {}", pr_number, pr_details.title);
    println!("Author: {}", pr_details.author.login);
    println!("Branch: {}", pr_details.head_ref_name);

    // Non-open and draft PRs are still checked out; the user is only warned.
    if pr_details.state != "OPEN" {
        eprintln!(
            "⚠️ Warning: PR #{} is {}. Proceeding with checkout...",
            pr_number, pr_details.state
        );
    }
    if pr_details.is_draft {
        eprintln!("⚠️ Warning: PR #{} is a DRAFT.", pr_number);
    }

    let current_repo_owner =
        git::get_repo_owner().context("Failed to determine repository owner from origin remote")?;
    // Fixed: this argument had been corrupted to `¤t_repo_owner`
    // (an HTML-entity mangling of `&current_repo_owner`).
    let is_fork = pr_details.is_fork(&current_repo_owner);
    let fork_owner = &pr_details.head_repository_owner.login;

    // The fork remote is ensured even when a custom local name is supplied,
    // because the remote-qualified branch still has to point at the fork.
    let remote_name = if is_fork {
        git::ensure_fork_remote(fork_owner)?
    } else {
        "origin".to_string()
    };

    let local_branch = custom_branch_name.map(String::from).unwrap_or_else(|| {
        if is_fork {
            fork_local_branch_name(fork_owner, &pr_details.head_ref_name)
        } else {
            pr_details.head_ref_name.clone()
        }
    });
    let remote_branch = format!("{}/{}", remote_name, pr_details.head_ref_name);

    Ok(PrCheckoutResult {
        local_branch,
        remote_branch,
    })
}
/// Outcome of resolving an `owner:branch` fork specification.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ForkBranchResult {
    /// Remote-qualified ref for the fork's branch, in `<remote>/<branch>` form.
    pub remote_ref: String,
    /// Base name to use for the local branch created from the fork's branch.
    pub template_base_name: String,
}
/// Resolves an `owner:branch` fork spec into a remote ref and a base name for
/// the local branch.
///
/// As a convenience, a PR whose head matches the spec is looked up and, when
/// one is found, a one-line summary (number, title, state suffix) is printed
/// to stdout. A failed or empty lookup is ignored.
///
/// # Errors
///
/// Returns an error if `git::ensure_fork_remote` fails for the fork owner.
pub fn resolve_fork_branch(fork_spec: &git::ForkBranchSpec) -> Result<ForkBranchResult> {
    let owner = &fork_spec.owner;
    let branch = &fork_spec.branch;

    // Purely informational: errors from the lookup are discarded on purpose.
    if let Some(pr) = github::find_pr_by_head_ref(owner, branch).ok().flatten() {
        let state_suffix = match (pr.state.as_str(), pr.is_draft) {
            ("OPEN", true) => " (draft)",
            ("MERGED", _) => " (merged)",
            ("CLOSED", _) => " (closed)",
            // Open non-draft PRs and unknown states get no suffix.
            _ => "",
        };
        println!("PR #{}: {}{}", pr.number, pr.title, state_suffix);
    }

    let template_base_name = fork_local_branch_name(owner, branch);
    let remote_name = git::ensure_fork_remote(owner)?;

    Ok(ForkBranchResult {
        remote_ref: format!("{}/{}", remote_name, branch),
        template_base_name,
    })
}
/// Works out whether `branch_name` refers to a remote or fork branch.
///
/// Returns `(remote_ref, local_name)`:
/// - for `owner:branch` fork syntax, the resolved fork remote ref and a base
///   name for the local branch;
/// - for `<remote>/<branch>` where `<remote>` is a configured remote and the
///   branch exists there (after a fetch if needed), `Some(branch_name)` and
///   the branch part;
/// - otherwise `None` and `branch_name` unchanged.
///
/// # Errors
///
/// Fails when `base` is given together with fork or remote syntax, or when an
/// underlying git operation (listing remotes, fetching, ref lookup) fails.
pub fn detect_remote_branch(
branch_name: &str,
base: Option<&str>,
) -> Result<(Option<String>, String)> {
// Run the shared detection logic against the real git-backed context.
detect_remote_branch_internal(branch_name, base, &RealRemoteDetectionContext)
}
fn detect_remote_branch_internal(
branch_name: &str,
base: Option<&str>,
ctx: &dyn RemoteDetectionContext,
) -> Result<(Option<String>, String)> {
if let Some(fork_spec) = git::parse_fork_branch_spec(branch_name) {
if base.is_some() {
return Err(anyhow!(
"Cannot use --base with 'owner:branch' syntax. \
The branch '{}' from '{}' will be used as the base.",
fork_spec.branch,
fork_spec.owner
));
}
let result = ctx.resolve_fork(&fork_spec)?;
return Ok((Some(result.remote_ref), result.template_base_name));
}
let remotes = ctx.list_remotes().context("Failed to list git remotes")?;
let detected_remote = remotes
.iter()
.find(|r| branch_name.starts_with(&format!("{}/", r)));
if let Some(remote_name) = detected_remote {
if base.is_some() {
return Err(anyhow!(
"Cannot use --base with a remote branch reference. \
The remote branch '{}' will be used as the base.",
branch_name
));
}
let spec = git::parse_remote_branch_spec(branch_name)
.context("Invalid remote branch format. Use <remote>/<branch>")?;
if spec.remote != *remote_name {
return Err(anyhow!("Mismatched remote detection"));
}
let remote_ref = format!("refs/remotes/{}", branch_name);
if !ctx.branch_exists(&remote_ref)? {
spinner::with_spinner(
&format!(
"Branch prefix matches remote '{}', verifying if it exists there...",
remote_name
),
|| {
ctx.fetch_remote(remote_name).with_context(|| {
format!(
"Failed to fetch from remote '{}'. Please check your network connection and try again.",
remote_name
)
})
},
)?;
if !ctx.branch_exists(&remote_ref)? {
eprintln!(
"Not found on '{}', creating local branch '{}'",
remote_name, branch_name
);
return Ok((None, branch_name.to_string()));
}
}
Ok((Some(branch_name.to_string()), spec.branch))
} else {
Ok((None, branch_name.to_string()))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;
    use std::collections::HashSet;

    /// In-memory `RemoteDetectionContext` used to drive
    /// `detect_remote_branch_internal` without touching a real repository.
    struct MockContext {
        /// Remote names reported by `list_remotes`.
        remotes: Vec<String>,
        /// Refs that exist before any fetch.
        existing_refs: HashSet<String>,
        /// Refs that only become visible once a fetch has succeeded.
        refs_available_after_fetch: HashSet<String>,
        /// When true, `fetch_remote` returns an error.
        fetch_should_fail: bool,
        /// Number of successful `fetch_remote` calls. A `Cell` is needed
        /// because the trait methods take `&self`.
        successful_fetches: Cell<usize>,
    }

    /// Converts a slice of string literals into an owned set.
    fn owned_set(items: &[&str]) -> HashSet<String> {
        items.iter().copied().map(String::from).collect()
    }

    impl MockContext {
        /// Mock with the given remotes and pre-existing refs; fetches succeed
        /// but reveal nothing new.
        fn new(remotes: &[&str], existing_refs: &[&str]) -> Self {
            Self {
                remotes: remotes.iter().copied().map(String::from).collect(),
                existing_refs: owned_set(existing_refs),
                refs_available_after_fetch: HashSet::new(),
                fetch_should_fail: false,
                successful_fetches: Cell::new(0),
            }
        }

        /// Mock whose `fetchable` refs appear only after a successful fetch.
        fn with_fetchable_refs(
            remotes: &[&str],
            existing_refs: &[&str],
            fetchable: &[&str],
        ) -> Self {
            Self {
                refs_available_after_fetch: owned_set(fetchable),
                ..Self::new(remotes, existing_refs)
            }
        }

        /// Mock whose fetches always fail.
        fn with_failing_fetch(remotes: &[&str], existing_refs: &[&str]) -> Self {
            Self {
                fetch_should_fail: true,
                ..Self::new(remotes, existing_refs)
            }
        }
    }

    impl RemoteDetectionContext for MockContext {
        fn list_remotes(&self) -> Result<Vec<String>> {
            Ok(self.remotes.clone())
        }

        fn branch_exists(&self, ref_name: &str) -> Result<bool> {
            // Fetchable refs must stay hidden until a fetch has happened,
            // otherwise the fetch path of the code under test is never hit.
            let fetched = self.successful_fetches.get() > 0;
            Ok(self.existing_refs.contains(ref_name)
                || (fetched && self.refs_available_after_fetch.contains(ref_name)))
        }

        fn resolve_fork(&self, spec: &git::ForkBranchSpec) -> Result<ForkBranchResult> {
            Ok(ForkBranchResult {
                remote_ref: format!("fork-{}/{}", spec.owner, spec.branch),
                template_base_name: format!("{}-{}", spec.owner, spec.branch),
            })
        }

        fn fetch_remote(&self, remote: &str) -> Result<()> {
            if self.fetch_should_fail {
                return Err(anyhow!("Network error: failed to fetch from '{}'", remote));
            }
            self.successful_fetches
                .set(self.successful_fetches.get() + 1);
            Ok(())
        }
    }

    #[test]
    fn test_simple_local_branch_no_slash() {
        let ctx = MockContext::new(&["origin"], &[]);
        let (remote, local) = detect_remote_branch_internal("feature", None, &ctx).unwrap();
        assert_eq!(remote, None);
        assert_eq!(local, "feature");
        assert_eq!(ctx.successful_fetches.get(), 0);
    }

    #[test]
    fn test_local_branch_with_slash_no_remote_match() {
        let ctx = MockContext::new(&["origin"], &[]);
        let (remote, local) = detect_remote_branch_internal("feature/foo", None, &ctx).unwrap();
        assert_eq!(remote, None);
        assert_eq!(local, "feature/foo");
        assert_eq!(ctx.successful_fetches.get(), 0);
    }

    #[test]
    fn test_remote_branch_exists() {
        let ctx = MockContext::new(&["origin"], &["refs/remotes/origin/feature"]);
        let (remote, local) = detect_remote_branch_internal("origin/feature", None, &ctx).unwrap();
        assert_eq!(remote, Some("origin/feature".to_string()));
        assert_eq!(local, "feature");
        // The ref is already known, so no fetch should be needed.
        assert_eq!(ctx.successful_fetches.get(), 0);
    }

    #[test]
    fn test_remote_prefix_but_branch_missing_issue_28() {
        let ctx = MockContext::new(&["origin", "ezh"], &[]);
        let (remote, local) =
            detect_remote_branch_internal("ezh/some-feature", None, &ctx).unwrap();
        assert_eq!(remote, None);
        assert_eq!(local, "ezh/some-feature");
    }

    #[test]
    fn test_origin_branch_missing_forgot_to_fetch() {
        let ctx = MockContext::new(&["origin"], &[]);
        let (remote, local) =
            detect_remote_branch_internal("origin/new-feature", None, &ctx).unwrap();
        assert_eq!(remote, None);
        assert_eq!(local, "origin/new-feature");
    }

    #[test]
    fn test_fork_syntax_owner_colon_branch() {
        let ctx = MockContext::new(&["origin"], &[]);
        let (remote, local) = detect_remote_branch_internal("owner:branch", None, &ctx).unwrap();
        assert_eq!(remote, Some("fork-owner/branch".to_string()));
        assert_eq!(local, "owner-branch");
    }

    #[test]
    fn test_fork_syntax_with_slash_in_branch() {
        let ctx = MockContext::new(&["origin"], &[]);
        let (remote, local) =
            detect_remote_branch_internal("owner:feature/foo", None, &ctx).unwrap();
        assert_eq!(remote, Some("fork-owner/feature/foo".to_string()));
        assert_eq!(local, "owner-feature/foo");
    }

    #[test]
    fn test_fork_syntax_avoids_main_conflict() {
        let ctx = MockContext::new(&["origin"], &[]);
        let (remote, local) = detect_remote_branch_internal("sundbp:main", None, &ctx).unwrap();
        assert_eq!(remote, Some("fork-sundbp/main".to_string()));
        assert_eq!(local, "sundbp-main");
    }

    #[test]
    fn test_base_flag_with_remote_syntax_errors() {
        let ctx = MockContext::new(&["origin"], &["refs/remotes/origin/feature"]);
        let err = detect_remote_branch_internal("origin/feature", Some("main"), &ctx).unwrap_err();
        assert!(err.to_string().contains("Cannot use --base"));
        assert!(err.to_string().contains("remote branch"));
    }

    #[test]
    fn test_base_flag_with_fork_syntax_errors() {
        let ctx = MockContext::new(&["origin"], &[]);
        let err = detect_remote_branch_internal("owner:branch", Some("main"), &ctx).unwrap_err();
        assert!(err.to_string().contains("Cannot use --base"));
        assert!(err.to_string().contains("owner:branch"));
    }

    #[test]
    fn test_multiple_remotes_correct_match() {
        let ctx = MockContext::new(
            &["origin", "upstream", "fork"],
            &["refs/remotes/upstream/develop"],
        );
        let (remote, local) =
            detect_remote_branch_internal("upstream/develop", None, &ctx).unwrap();
        assert_eq!(remote, Some("upstream/develop".to_string()));
        assert_eq!(local, "develop");
    }

    #[test]
    fn test_nested_slashes_in_branch_name() {
        let ctx = MockContext::new(&["origin"], &[]);
        let (remote, local) =
            detect_remote_branch_internal("feature/sub/task", None, &ctx).unwrap();
        assert_eq!(remote, None);
        assert_eq!(local, "feature/sub/task");
    }

    #[test]
    fn test_remote_with_nested_branch() {
        let ctx = MockContext::new(&["origin"], &["refs/remotes/origin/feature/sub/task"]);
        let (remote, local) =
            detect_remote_branch_internal("origin/feature/sub/task", None, &ctx).unwrap();
        assert_eq!(remote, Some("origin/feature/sub/task".to_string()));
        assert_eq!(local, "feature/sub/task");
    }

    #[test]
    fn test_fetch_makes_remote_branch_available() {
        let ctx = MockContext::with_fetchable_refs(
            &["origin"],
            &[],
            &["refs/remotes/origin/new-feature"],
        );
        let (remote, local) =
            detect_remote_branch_internal("origin/new-feature", None, &ctx).unwrap();
        assert_eq!(remote, Some("origin/new-feature".to_string()));
        assert_eq!(local, "new-feature");
        // The ref only appears after fetching, so exactly one fetch must have run.
        assert_eq!(ctx.successful_fetches.get(), 1);
    }

    #[test]
    fn test_fetch_succeeds_but_branch_not_found_creates_local() {
        let ctx = MockContext::new(&["ezh"], &[]);
        let (remote, local) = detect_remote_branch_internal("ezh/my-feature", None, &ctx).unwrap();
        assert_eq!(remote, None);
        assert_eq!(local, "ezh/my-feature");
        assert_eq!(ctx.successful_fetches.get(), 1);
    }

    #[test]
    fn test_fetch_fails_returns_error() {
        let ctx = MockContext::with_failing_fetch(&["origin"], &[]);
        let err = detect_remote_branch_internal("origin/new-feature", None, &ctx).unwrap_err();
        assert!(err.to_string().contains("Failed to fetch"));
        assert!(err.to_string().contains("origin"));
        assert_eq!(ctx.successful_fetches.get(), 0);
    }
}