Skip to main content

workon/
pr.rs

1//! Pull request support for creating worktrees from PR references.
2//!
3//! This module enables creating worktrees directly from pull request references,
4//! making it easy to review PRs in isolated worktrees.
5//!
6//! ## PR Reference Parsing
7//!
8//! Supports multiple PR reference formats:
9//! - `#123` - GitHub shorthand (most common)
10//! - `pr#123` or `pr-123` - Explicit PR references
11//! - `https://github.com/owner/repo/pull/123` - Full GitHub PR URL
12//! - `origin/pull/123/head` - Direct remote ref (less common)
13//!
14//! Parsing is lenient - if it looks like a PR reference, we'll try to extract the number.
15//!
16//! ## Smart Routing
17//!
18//! The CLI's smart routing (in main.rs) automatically detects PR references:
19//! ```bash
20//! git workon #123        # Routes to `new` command with PR reference
21//! git workon pr#123      # Same - creates PR worktree
22//! git workon feature     # Routes to `find` command (not a PR)
23//! ```
24//!
25//! ## Remote Detection Algorithm
26//!
27//! To fetch PRs, we need to determine which remote to use. The detection strategy:
28//! 1. Check for `upstream` remote (common in fork workflows)
29//! 2. Fall back to `origin` remote (most common)
30//! 3. Use first available remote (rare, but handles edge cases)
31//!
32//! This handles both direct repository workflows and fork-based workflows.
33//!
34//! ## Auto-Fetch Strategy
35//!
36//! PR branches are fetched automatically using gh CLI metadata:
37//! ```text
38//! git fetch <remote> +refs/heads/{branch}:refs/remotes/<remote>/{branch}
39//! ```
40//!
41//! Where `{branch}` is the actual branch name from the PR (obtained via gh CLI).
42//! The `+` forces the fetch even if not fast-forward, ensuring we always get the latest PR state.
43//!
44//! For fork PRs, a fork remote is automatically added and the branch is fetched from it.
45//! For non-fork PRs, the branch is fetched from the detected remote (origin/upstream).
46//!
47//! ## Worktree Naming
48//!
49//! Worktree names are generated from `workon.prFormat` config (default: `pr-{number}`):
50//! - `pr-123` (default format)
51//! - `#123` (if configured with `#{number}`)
52//! - `pull-123` (if configured with `pull-{number}`)
53//!
//! The format must contain the `{number}` placeholder.
55//!
56//! ## Example Usage
57//!
58//! ```bash
59//! # Create worktree for PR #123 (auto-detects remote, auto-fetches)
60//! git workon #123
61//!
62//! # Explicit PR reference
63//! git workon new pr#456
64//!
65//! # From GitHub URL
66//! git workon new https://github.com/user/repo/pull/789
67//!
68//! # Configure custom naming
69//! git config workon.prFormat "review-{number}"
70//! git workon #123  # Creates worktree named "review-123"
71//! ```
72//!
73//! ## gh CLI Integration
74//!
75//! PR support integrates with gh CLI for rich metadata:
76//! - **Format placeholders**: {number}, {title}, {author}, {branch}
77//! - **Fork support**: Auto-adds fork remotes and fetches fork branches
78//! - **Metadata**: Fetches PR title, author, branch names, and state
79//! - **Validation**: Checks PR exists before creating worktree
80
81use git2::{FetchOptions, Repository};
82use log::debug;
83
84use crate::{
85    error::{PrError, Result},
86    get_remote_callbacks,
87};
88
/// A pull request reference parsed from user input by [`parse_pr_reference`].
///
/// Only the PR number is carried forward; the branch, fork and remote to use
/// are resolved later from the PR metadata.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PullRequest {
    /// The PR number extracted from the reference string.
    pub number: u32,
    /// Remote name associated with the reference.
    ///
    /// Currently always `None`: none of the parsers in this module populate it.
    /// That includes the `origin/pull/123/head` form, whose remote segment is
    /// ignored; the remote is chosen later by remote detection.
    pub remote: Option<String>,
}
97
/// Pull request details as reported by `gh pr view --json …`.
///
/// Built by [`fetch_pr_metadata`] and consumed when choosing the remote,
/// fetching the head branch and naming the worktree.
#[derive(Clone, Debug)]
pub struct PrMetadata {
    /// The pull request number.
    pub number: u32,
    /// Human-readable PR title.
    pub title: String,
    /// Login of the user who opened the PR.
    pub author: String,
    /// Source (head) branch of the PR.
    pub head_ref: String,
    /// Target (base) branch the PR wants to merge into.
    pub base_ref: String,
    /// Whether the head branch lives in a different (forked) repository.
    pub is_fork: bool,
    /// Owner login of the forked repository; only set for fork PRs.
    pub fork_owner: Option<String>,
    /// Clone URL of the forked repository; only set for fork PRs.
    pub fork_url: Option<String>,
}
118
119/// Parse a PR reference from user input
120///
121/// Supported formats:
122/// - `#123` - GitHub shorthand
123/// - `pr#123` or `pr-123` - Explicit PR references
124/// - `https://github.com/owner/repo/pull/123` - GitHub PR URL
125/// - `origin/pull/123/head` - Direct remote ref
126///
127/// Returns `Ok(None)` if the input is not a PR reference.
128/// Returns `Ok(Some(PullRequest))` if successfully parsed.
129/// Returns `Err` if the input looks like a PR reference but is malformed.
130pub fn parse_pr_reference(input: &str) -> Result<Option<PullRequest>> {
131    // Try #123 format
132    if let Some(num_str) = input.strip_prefix('#') {
133        return parse_number(num_str, input).map(|num| {
134            Some(PullRequest {
135                number: num,
136                remote: None,
137            })
138        });
139    }
140
141    // Try pr#123 format
142    if let Some(num_str) = input.strip_prefix("pr#") {
143        return parse_number(num_str, input).map(|num| {
144            Some(PullRequest {
145                number: num,
146                remote: None,
147            })
148        });
149    }
150
151    // Try pr-123 format
152    if let Some(num_str) = input.strip_prefix("pr-") {
153        return parse_number(num_str, input).map(|num| {
154            Some(PullRequest {
155                number: num,
156                remote: None,
157            })
158        });
159    }
160
161    // Try GitHub URL: https://github.com/owner/repo/pull/123
162    if input.contains("github.com") && input.contains("/pull/") {
163        return parse_github_url(input);
164    }
165
166    // Try remote ref format: origin/pull/123/head
167    if input.contains("/pull/") && input.ends_with("/head") {
168        return parse_remote_ref(input);
169    }
170
171    // Not a PR reference
172    Ok(None)
173}
174
175/// Helper to parse a number string
176fn parse_number(num_str: &str, original_input: &str) -> Result<u32> {
177    num_str.parse::<u32>().map_err(|_| {
178        PrError::InvalidReference {
179            input: original_input.to_string(),
180        }
181        .into()
182    })
183}
184
185/// Parse GitHub PR URL
186fn parse_github_url(url: &str) -> Result<Option<PullRequest>> {
187    // Extract the PR number from URL like: https://github.com/owner/repo/pull/123
188    let parts: Vec<&str> = url.split('/').collect();
189
190    // Find "pull" in the path and get the number after it
191    for (i, &part) in parts.iter().enumerate() {
192        if part == "pull" && i + 1 < parts.len() {
193            let num_str = parts[i + 1];
194            let number = parse_number(num_str, url)?;
195            return Ok(Some(PullRequest {
196                number,
197                remote: None,
198            }));
199        }
200    }
201
202    Err(PrError::InvalidReference {
203        input: url.to_string(),
204    }
205    .into())
206}
207
208/// Parse remote ref format: origin/pull/123/head
209fn parse_remote_ref(ref_str: &str) -> Result<Option<PullRequest>> {
210    // Format: remote/pull/number/head
211    let parts: Vec<&str> = ref_str.split('/').collect();
212
213    if parts.len() >= 4 && parts[parts.len() - 3] == "pull" && parts[parts.len() - 1] == "head" {
214        let num_str = parts[parts.len() - 2];
215        let number = parse_number(num_str, ref_str)?;
216        return Ok(Some(PullRequest {
217            number,
218            remote: None,
219        }));
220    }
221
222    Err(PrError::InvalidReference {
223        input: ref_str.to_string(),
224    }
225    .into())
226}
227
228/// Return `Ok(())` if the `gh` CLI is installed and reachable in `PATH`.
229///
230/// Returns [`PrError::GhNotInstalled`] if `gh` cannot be executed.
231pub fn check_gh_available() -> Result<()> {
232    std::process::Command::new("gh")
233        .arg("--version")
234        .output()
235        .map_err(|_| PrError::GhNotInstalled)?;
236    Ok(())
237}
238
239/// Fetch PR metadata for `pr_number` using the `gh` CLI.
240///
241/// Runs `gh pr view <pr_number> --json ...` and parses the JSON output.
242/// Requires `gh` to be authenticated (`gh auth login`).
243pub fn fetch_pr_metadata(pr_number: u32) -> Result<PrMetadata> {
244    // Ensure gh is available
245    check_gh_available()?;
246
247    // Fetch PR metadata with single gh command
248    let output = std::process::Command::new("gh")
249        .args([
250            "pr",
251            "view",
252            &pr_number.to_string(),
253            "--json",
254            "number,title,author,headRefName,baseRefName,isCrossRepository,headRepository",
255        ])
256        .output()
257        .map_err(|e| PrError::GhFetchFailed {
258            message: e.to_string(),
259        })?;
260
261    if !output.status.success() {
262        let stderr = String::from_utf8_lossy(&output.stderr);
263        return Err(PrError::GhFetchFailed {
264            message: stderr.to_string(),
265        }
266        .into());
267    }
268
269    // Parse JSON response
270    let json_str = String::from_utf8_lossy(&output.stdout);
271    let json: serde_json::Value =
272        serde_json::from_str(&json_str).map_err(|e| PrError::GhJsonParseFailed {
273            message: e.to_string(),
274        })?;
275
276    // Extract fields
277    let number = json["number"]
278        .as_u64()
279        .ok_or_else(|| PrError::GhJsonParseFailed {
280            message: "Missing 'number' field".to_string(),
281        })? as u32;
282
283    let title = json["title"]
284        .as_str()
285        .ok_or_else(|| PrError::GhJsonParseFailed {
286            message: "Missing 'title' field".to_string(),
287        })?
288        .to_string();
289
290    let author = json["author"]["login"]
291        .as_str()
292        .ok_or_else(|| PrError::GhJsonParseFailed {
293            message: "Missing 'author.login' field".to_string(),
294        })?
295        .to_string();
296
297    let head_ref = json["headRefName"]
298        .as_str()
299        .ok_or_else(|| PrError::GhJsonParseFailed {
300            message: "Missing 'headRefName' field".to_string(),
301        })?
302        .to_string();
303
304    let base_ref = json["baseRefName"]
305        .as_str()
306        .ok_or_else(|| PrError::GhJsonParseFailed {
307            message: "Missing 'baseRefName' field".to_string(),
308        })?
309        .to_string();
310
311    let is_fork = json["isCrossRepository"].as_bool().unwrap_or(false);
312
313    let (fork_owner, fork_url) = if is_fork {
314        let owner = json["headRepository"]["owner"]["login"]
315            .as_str()
316            .ok_or(PrError::MissingForkOwner)?
317            .to_string();
318        let url = json["headRepository"]["url"]
319            .as_str()
320            .map(|s| s.to_string());
321        (Some(owner), url)
322    } else {
323        (None, None)
324    };
325
326    Ok(PrMetadata {
327        number,
328        title,
329        author,
330        head_ref,
331        base_ref,
332        is_fork,
333        fork_owner,
334        fork_url,
335    })
336}
337
/// Reduce arbitrary text to a lowercase slug usable in branch/worktree names.
///
/// ASCII letters are lowercased, and ASCII digits, `-` and `_` are kept.
/// Every other character (spaces, `/`, punctuation, non-ASCII) becomes `-`.
/// Runs of `-` collapse to one, and leading/trailing `-` or `_` are trimmed.
fn sanitize_for_branch_name(s: &str) -> String {
    let mut slug = String::with_capacity(s.len());

    for ch in s.chars() {
        let mapped = if ch.is_ascii_alphanumeric() || ch == '_' {
            ch.to_ascii_lowercase()
        } else {
            '-'
        };
        // Skip a dash that would directly follow another dash.
        if mapped == '-' && slug.ends_with('-') {
            continue;
        }
        slug.push(mapped);
    }

    slug.trim_matches(|c| c == '-' || c == '_').to_string()
}
367
368/// Expand all placeholders in `format` using `metadata`.
369///
370/// Supported placeholders: `{number}`, `{title}`, `{author}`, `{branch}`.
371/// Title, author, and branch values are sanitized for use in branch/directory names.
372pub fn format_pr_name_with_metadata(format: &str, metadata: &PrMetadata) -> String {
373    format
374        .replace("{number}", &metadata.number.to_string())
375        .replace("{title}", &sanitize_for_branch_name(&metadata.title))
376        .replace("{author}", &sanitize_for_branch_name(&metadata.author))
377        .replace("{branch}", &sanitize_for_branch_name(&metadata.head_ref))
378}
379
380/// Check if a string looks like a PR reference
381///
382/// This is a quick check used for routing decisions.
383pub fn is_pr_reference(input: &str) -> bool {
384    parse_pr_reference(input).ok().flatten().is_some()
385}
386
/// Priority tier for the shared `upstream → origin → others` remote precedence.
///
/// This is the single encoding of the precedence ADR-024 prescribes for every
/// remote decision. [`preferred_remote_order`] sorts by it, and
/// `resolve_remote_tracking` (worktree.rs) uses equal tiers to detect
/// ambiguity. Lower is more preferred: `upstream` is 0, `origin` is 1, and
/// every other remote shares tier 2.
pub fn remote_priority(remote: &str) -> usize {
    const PREFERRED: [&str; 2] = ["upstream", "origin"];
    PREFERRED
        .iter()
        .position(|&name| name == remote)
        .unwrap_or(PREFERRED.len())
}
400
401/// Returns remotes in preferred order: upstream first, then origin, then all
402/// others in configuration order (the sort is stable).
403pub fn preferred_remote_order(repo: &Repository) -> Vec<String> {
404    let Ok(remotes) = repo.remotes() else {
405        return vec![];
406    };
407    let mut all: Vec<String> = remotes
408        .iter()
409        .flatten()
410        .flatten()
411        .map(str::to_string)
412        .collect();
413    all.sort_by_key(|r| remote_priority(r));
414    all
415}
416
417/// Select which remote to use for fetching PR refs.
418///
419/// Priority: `upstream` → `origin` → first available remote.
420/// Returns [`PrError::NoRemoteConfigured`] if the repository has no remotes.
421pub fn detect_pr_remote(repo: &Repository) -> Result<String> {
422    preferred_remote_order(repo)
423        .into_iter()
424        .next()
425        .ok_or_else(|| PrError::NoRemoteConfigured.into())
426}
427
428/// Ensure a remote for a fork PR exists, then return its name.
429///
430/// For non-fork PRs this is equivalent to [`detect_pr_remote`].
431/// For fork PRs, a remote named `pr-{number}-fork` is added if it doesn't
432/// already exist, pointing at the fork's clone URL.
433pub fn setup_fork_remote(repo: &Repository, metadata: &PrMetadata) -> Result<String> {
434    if !metadata.is_fork {
435        // Not a fork - use regular remote
436        return detect_pr_remote(repo);
437    }
438
439    // Fork PR - need to add fork remote
440    let _fork_owner = metadata
441        .fork_owner
442        .as_ref()
443        .ok_or(PrError::MissingForkOwner)?;
444
445    let fork_url = metadata
446        .fork_url
447        .as_ref()
448        .ok_or(PrError::MissingForkOwner)?;
449
450    // Check if fork remote already exists
451    let fork_remote_name = format!("pr-{}-fork", metadata.number);
452
453    if repo.find_remote(&fork_remote_name).is_ok() {
454        debug!("Fork remote {} already exists", fork_remote_name);
455        return Ok(fork_remote_name);
456    }
457
458    // Add fork as remote
459    debug!("Adding fork remote: {} -> {}", fork_remote_name, fork_url);
460    repo.remote(&fork_remote_name, fork_url)
461        .map_err(|e| PrError::FetchFailed {
462            remote: fork_remote_name.clone(),
463            message: format!("Failed to add fork remote: {}", e),
464        })?;
465
466    Ok(fork_remote_name)
467}
468
469/// Fetch `branch` from `remote_name`, making it available as
470/// `refs/remotes/{remote_name}/{branch}`.
471///
472/// This is used for both fork and non-fork PRs to fetch the PR's head branch
473/// identified via `gh` CLI metadata. If the ref already exists locally the
474/// fetch is skipped.
475pub fn fetch_branch(repo: &Repository, remote_name: &str, branch: &str) -> Result<()> {
476    // Check if branch already exists locally
477    let branch_ref = format!("refs/remotes/{}/{}", remote_name, branch);
478    if repo.find_reference(&branch_ref).is_ok() {
479        debug!("Branch ref {} already exists", branch_ref);
480        return Ok(());
481    }
482
483    debug!("Fetching branch {} from remote {}", branch, remote_name);
484
485    let refspec = format!(
486        "+refs/heads/{}:refs/remotes/{}/{}",
487        branch, remote_name, branch
488    );
489
490    let remote_url = repo
491        .find_remote(remote_name)
492        .ok()
493        .and_then(|r| r.url().ok().map(str::to_string));
494    let auth = get_remote_callbacks(repo, remote_url.as_deref())?;
495    let mut fetch_options = FetchOptions::new();
496    fetch_options.remote_callbacks(auth.callbacks());
497
498    repo.find_remote(remote_name)?
499        .fetch(
500            &[refspec.as_str()],
501            Some(&mut fetch_options),
502            Some("Fetching PR branch"),
503        )
504        .map_err(|e| PrError::FetchFailed {
505            remote: remote_name.to_string(),
506            message: e.message().to_string(),
507        })?;
508
509    debug!("Successfully fetched branch {}", branch);
510    Ok(())
511}
512
/// Build a PR worktree name from `format`.
///
/// Every occurrence of `{number}` is replaced with `pr_number`; all other
/// text, including unknown placeholders, is left untouched.
pub fn format_pr_name(format: &str, pr_number: u32) -> String {
    let number = pr_number.to_string();
    format.replace("{number}", number.as_str())
}
519
520/// Prepare everything needed to create a worktree for PR `pr_number`.
521///
522/// Orchestrates the complete PR workflow:
523/// 1. Checks that `gh` CLI is available
524/// 2. Fetches PR metadata via `gh`
525/// 3. Sets up a fork remote if the PR is cross-repository
526/// 4. Fetches the PR's head branch
527/// 5. Formats the worktree name using `pr_format`
528///
529/// Returns `(worktree_name, remote_ref, base_branch)` ready for `add_worktree`.
530pub fn prepare_pr_worktree(
531    repo: &Repository,
532    pr_number: u32,
533    pr_format: &str,
534) -> Result<(String, String, String)> {
535    debug!("Preparing PR worktree for PR #{}", pr_number);
536
537    // Fetch PR metadata from gh CLI
538    let metadata = fetch_pr_metadata(pr_number)?;
539    debug!(
540        "Fetched metadata: title='{}', author='{}', is_fork={}",
541        metadata.title, metadata.author, metadata.is_fork
542    );
543
544    // Setup remote and fetch branch
545    // For fork PRs: setup fork remote and fetch from it
546    // For non-fork PRs: use existing remote (origin/upstream)
547    let remote_name = if metadata.is_fork {
548        setup_fork_remote(repo, &metadata)?
549    } else {
550        detect_pr_remote(repo)?
551    };
552
553    // Fetch the actual branch from gh CLI metadata (works for both fork and non-fork)
554    fetch_branch(repo, &remote_name, &metadata.head_ref)?;
555
556    // Format worktree name using metadata
557    let worktree_name = format_pr_name_with_metadata(pr_format, &metadata);
558    debug!("Worktree name: {}", worktree_name);
559
560    // Build remote ref using the actual branch from metadata
561    let remote_ref = format!("{}/{}", remote_name, metadata.head_ref);
562    debug!("Remote ref: {}", remote_ref);
563
564    Ok((worktree_name, remote_ref, metadata.base_ref))
565}
566
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `input` and assert that it is recognised as a PR reference.
    fn parse_ok(input: &str) -> PullRequest {
        parse_pr_reference(input)
            .unwrap_or_else(|e| panic!("{input:?} should parse, got error: {e:?}"))
            .unwrap_or_else(|| panic!("{input:?} should be a PR reference"))
    }

    #[test]
    fn parses_every_supported_reference_form() {
        let cases = [
            ("#123", 123),
            ("pr#456", 456),
            ("pr-789", 789),
            ("https://github.com/owner/repo/pull/999", 999),
            ("origin/pull/111/head", 111),
        ];
        for (input, expected) in cases {
            let pr = parse_ok(input);
            assert_eq!(pr.number, expected, "number parsed from {input:?}");
            assert_eq!(pr.remote, None, "remote parsed from {input:?}");
        }
    }

    #[test]
    fn plain_branch_name_is_not_a_pr_reference() {
        let result = parse_pr_reference("my-feature-branch").unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn non_numeric_hash_reference_is_an_error() {
        assert!(parse_pr_reference("#abc").is_err());
    }

    #[test]
    fn is_pr_reference_routing() {
        for input in [
            "#123",
            "pr#456",
            "pr-789",
            "https://github.com/owner/repo/pull/999",
        ] {
            assert!(is_pr_reference(input), "{input:?} should route as a PR");
        }
        for input in ["my-branch", "feature"] {
            assert!(!is_pr_reference(input), "{input:?} should not route as a PR");
        }
    }

    #[test]
    fn format_pr_name_substitutes_number() {
        let cases = [
            ("pr-{number}", 123, "pr-123"),
            ("review-{number}", 456, "review-456"),
            ("{number}-test", 789, "789-test"),
        ];
        for (format, number, expected) in cases {
            assert_eq!(format_pr_name(format, number), expected);
        }
    }

    #[test]
    fn sanitize_branch_name() {
        let cases = [
            ("Fix Bug #123", "fix-bug-123"),
            ("Add Feature (v2)", "add-feature-v2"),
            ("john-smith", "john-smith"),
            ("Fix: Authentication Issue", "fix-authentication-issue"),
            ("Test@#$%", "test"),
        ];
        for (raw, expected) in cases {
            assert_eq!(sanitize_for_branch_name(raw), expected, "sanitizing {raw:?}");
        }
    }

    #[test]
    fn format_with_metadata() {
        let metadata = PrMetadata {
            number: 123,
            title: "Fix Authentication Bug".to_string(),
            author: "john-smith".to_string(),
            head_ref: "feature/fix-auth".to_string(),
            base_ref: "main".to_string(),
            is_fork: false,
            fork_owner: None,
            fork_url: None,
        };

        let cases = [
            ("pr-{number}", "pr-123"),
            ("{number}-{title}", "123-fix-authentication-bug"),
            ("{author}/pr-{number}", "john-smith/pr-123"),
            ("{branch}-{number}", "feature-fix-auth-123"),
        ];
        for (format, expected) in cases {
            assert_eq!(format_pr_name_with_metadata(format, &metadata), expected);
        }
    }

    // Integration tests that need a real gh CLI; run with `--ignored`.
    #[test]
    #[ignore]
    fn test_gh_cli_available() {
        check_gh_available().expect("gh CLI should be installed");
    }

    #[test]
    #[ignore]
    fn test_fetch_real_pr_metadata() {
        // Requires gh CLI and auth, and is resolved against the repository in
        // the current directory. Replace the PR number with one that exists
        // in the repository under test.
        let metadata = fetch_pr_metadata(1).expect("Failed to fetch PR metadata");
        assert_eq!(metadata.number, 1);
        assert!(!metadata.title.is_empty());
        assert!(!metadata.author.is_empty());
    }
}