Skip to main content

workon/
pr.rs

1//! Pull request support for creating worktrees from PR references.
2//!
3//! This module enables creating worktrees directly from pull request references,
4//! making it easy to review PRs in isolated worktrees.
5//!
6//! ## PR Reference Parsing
7//!
8//! Supports multiple PR reference formats:
9//! - `#123` - GitHub shorthand (most common)
10//! - `pr#123` or `pr-123` - Explicit PR references
11//! - `https://github.com/owner/repo/pull/123` - Full GitHub PR URL
12//! - `origin/pull/123/head` - Direct remote ref (less common)
13//!
14//! Parsing is lenient - if it looks like a PR reference, we'll try to extract the number.
15//!
16//! ## Smart Routing
17//!
18//! The CLI's smart routing (in main.rs) automatically detects PR references:
19//! ```bash
20//! git workon #123        # Routes to `new` command with PR reference
21//! git workon pr#123      # Same - creates PR worktree
22//! git workon feature     # Routes to `find` command (not a PR)
23//! ```
24//!
25//! ## Remote Detection Algorithm
26//!
27//! To fetch PRs, we need to determine which remote to use. The detection strategy:
28//! 1. Check for `upstream` remote (common in fork workflows)
29//! 2. Fall back to `origin` remote (most common)
30//! 3. Use first available remote (rare, but handles edge cases)
31//!
32//! This handles both direct repository workflows and fork-based workflows.
33//!
34//! ## Auto-Fetch Strategy
35//!
36//! PR branches are fetched automatically using gh CLI metadata:
37//! ```text
38//! git fetch <remote> +refs/heads/{branch}:refs/remotes/<remote>/{branch}
39//! ```
40//!
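//! Where `{branch}` is the actual branch name from the PR (obtained via gh CLI).
//! The `+` forces the ref update even when it is not a fast-forward (e.g. after the
//! PR author force-pushes). If `refs/remotes/<remote>/{branch}` already exists
//! locally, the fetch is skipped and the existing ref is used as-is.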
43//!
44//! For fork PRs, a fork remote is automatically added and the branch is fetched from it.
45//! For non-fork PRs, the branch is fetched from the detected remote (origin/upstream).
46//!
47//! ## Worktree Naming
48//!
49//! Worktree names are generated from `workon.prFormat` config (default: `pr-{number}`):
50//! - `pr-123` (default format)
51//! - `#123` (if configured with `#{number}`)
52//! - `pull-123` (if configured with `pull-{number}`)
53//!
//! The format must contain the `{number}` placeholder.
55//!
56//! ## Example Usage
57//!
58//! ```bash
59//! # Create worktree for PR #123 (auto-detects remote, auto-fetches)
60//! git workon #123
61//!
62//! # Explicit PR reference
63//! git workon new pr#456
64//!
65//! # From GitHub URL
66//! git workon new https://github.com/user/repo/pull/789
67//!
68//! # Configure custom naming
69//! git config workon.prFormat "review-{number}"
70//! git workon #123  # Creates worktree named "review-123"
71//! ```
72//!
73//! ## gh CLI Integration
74//!
75//! PR support integrates with gh CLI for rich metadata:
76//! - **Format placeholders**: {number}, {title}, {author}, {branch}
77//! - **Fork support**: Auto-adds fork remotes and fetches fork branches
78//! - **Metadata**: Fetches PR title, author, branch names, and state
79//! - **Validation**: Checks PR exists before creating worktree
80
81use git2::{FetchOptions, Repository};
82use log::debug;
83
84use crate::{
85    error::{PrError, Result},
86    get_remote_callbacks,
87};
88
/// A parsed pull request reference from user input.
///
/// Produced by [`parse_pr_reference`] from forms such as `#123`, `pr-123`,
/// a GitHub PR URL, or `origin/pull/123/head`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PullRequest {
    /// The PR number extracted from the reference string.
    pub number: u32,
    /// Optional remote name associated with the reference.
    ///
    /// NOTE: [`parse_pr_reference`] currently always leaves this `None`, even for
    /// remote-ref inputs such as `origin/pull/123/head`. The remote used for
    /// fetching is chosen separately (see [`detect_pr_remote`]).
    pub remote: Option<String>,
}
97
/// PR metadata fetched from the `gh` CLI.
///
/// Built by [`fetch_pr_metadata`] from the JSON printed by `gh pr view --json ...`.
/// It is used to choose or add the remote to fetch from, to fetch the PR head
/// branch, and to expand worktree-name placeholders.
#[derive(Debug, Clone)]
pub struct PrMetadata {
    /// PR number.
    pub number: u32,
    /// PR title. Sanitized before being substituted for `{title}` in worktree names.
    pub title: String,
    /// GitHub login of the PR author. Sanitized before being substituted for `{author}`.
    pub author: String,
    /// Name of the branch that the PR was created from (`headRefName`).
    /// This is the branch that gets fetched and checked out.
    pub head_ref: String,
    /// Name of the branch the PR targets (`baseRefName`).
    pub base_ref: String,
    /// True if the PR comes from a forked repository (`isCrossRepository`).
    pub is_fork: bool,
    /// GitHub login of the fork owner; `None` unless `is_fork` is true.
    pub fork_owner: Option<String>,
    /// Clone URL of the fork repository.
    ///
    /// `None` for non-fork PRs. It can also be `None` for a fork PR when `gh`
    /// did not report enough information to determine it.
    pub fork_url: Option<String>,
}
118
119/// Parse a PR reference from user input
120///
121/// Supported formats:
122/// - `#123` - GitHub shorthand
123/// - `pr#123` or `pr-123` - Explicit PR references
124/// - `https://github.com/owner/repo/pull/123` - GitHub PR URL
125/// - `origin/pull/123/head` - Direct remote ref
126///
127/// Returns `Ok(None)` if the input is not a PR reference.
128/// Returns `Ok(Some(PullRequest))` if successfully parsed.
129/// Returns `Err` if the input looks like a PR reference but is malformed.
130pub fn parse_pr_reference(input: &str) -> Result<Option<PullRequest>> {
131    // Try #123 format
132    if let Some(num_str) = input.strip_prefix('#') {
133        return parse_number(num_str, input).map(|num| {
134            Some(PullRequest {
135                number: num,
136                remote: None,
137            })
138        });
139    }
140
141    // Try pr#123 format
142    if let Some(num_str) = input.strip_prefix("pr#") {
143        return parse_number(num_str, input).map(|num| {
144            Some(PullRequest {
145                number: num,
146                remote: None,
147            })
148        });
149    }
150
151    // Try pr-123 format
152    if let Some(num_str) = input.strip_prefix("pr-") {
153        return parse_number(num_str, input).map(|num| {
154            Some(PullRequest {
155                number: num,
156                remote: None,
157            })
158        });
159    }
160
161    // Try GitHub URL: https://github.com/owner/repo/pull/123
162    if input.contains("github.com") && input.contains("/pull/") {
163        return parse_github_url(input);
164    }
165
166    // Try remote ref format: origin/pull/123/head
167    if input.contains("/pull/") && input.ends_with("/head") {
168        return parse_remote_ref(input);
169    }
170
171    // Not a PR reference
172    Ok(None)
173}
174
175/// Helper to parse a number string
176fn parse_number(num_str: &str, original_input: &str) -> Result<u32> {
177    num_str.parse::<u32>().map_err(|_| {
178        PrError::InvalidReference {
179            input: original_input.to_string(),
180        }
181        .into()
182    })
183}
184
185/// Parse GitHub PR URL
186fn parse_github_url(url: &str) -> Result<Option<PullRequest>> {
187    // Extract the PR number from URL like: https://github.com/owner/repo/pull/123
188    let parts: Vec<&str> = url.split('/').collect();
189
190    // Find "pull" in the path and get the number after it
191    for (i, &part) in parts.iter().enumerate() {
192        if part == "pull" && i + 1 < parts.len() {
193            let num_str = parts[i + 1];
194            let number = parse_number(num_str, url)?;
195            return Ok(Some(PullRequest {
196                number,
197                remote: None,
198            }));
199        }
200    }
201
202    Err(PrError::InvalidReference {
203        input: url.to_string(),
204    }
205    .into())
206}
207
208/// Parse remote ref format: origin/pull/123/head
209fn parse_remote_ref(ref_str: &str) -> Result<Option<PullRequest>> {
210    // Format: remote/pull/number/head
211    let parts: Vec<&str> = ref_str.split('/').collect();
212
213    if parts.len() >= 4 && parts[parts.len() - 3] == "pull" && parts[parts.len() - 1] == "head" {
214        let num_str = parts[parts.len() - 2];
215        let number = parse_number(num_str, ref_str)?;
216        return Ok(Some(PullRequest {
217            number,
218            remote: None,
219        }));
220    }
221
222    Err(PrError::InvalidReference {
223        input: ref_str.to_string(),
224    }
225    .into())
226}
227
228/// Return `Ok(())` if the `gh` CLI is installed and reachable in `PATH`.
229///
230/// Returns [`PrError::GhNotInstalled`] if `gh` cannot be executed.
231pub fn check_gh_available() -> Result<()> {
232    std::process::Command::new("gh")
233        .arg("--version")
234        .output()
235        .map_err(|_| PrError::GhNotInstalled)?;
236    Ok(())
237}
238
239/// Fetch PR metadata for `pr_number` using the `gh` CLI.
240///
241/// Runs `gh pr view <pr_number> --json ...` and parses the JSON output.
242/// Requires `gh` to be authenticated (`gh auth login`).
243pub fn fetch_pr_metadata(pr_number: u32) -> Result<PrMetadata> {
244    // Ensure gh is available
245    check_gh_available()?;
246
247    // Fetch PR metadata with single gh command
248    let output = std::process::Command::new("gh")
249        .args([
250            "pr",
251            "view",
252            &pr_number.to_string(),
253            "--json",
254            "number,title,author,headRefName,baseRefName,isCrossRepository,headRepository",
255        ])
256        .output()
257        .map_err(|e| PrError::GhFetchFailed {
258            message: e.to_string(),
259        })?;
260
261    if !output.status.success() {
262        let stderr = String::from_utf8_lossy(&output.stderr);
263        return Err(PrError::GhFetchFailed {
264            message: stderr.to_string(),
265        }
266        .into());
267    }
268
269    // Parse JSON response
270    let json_str = String::from_utf8_lossy(&output.stdout);
271    let json: serde_json::Value =
272        serde_json::from_str(&json_str).map_err(|e| PrError::GhJsonParseFailed {
273            message: e.to_string(),
274        })?;
275
276    // Extract fields
277    let number = json["number"]
278        .as_u64()
279        .ok_or_else(|| PrError::GhJsonParseFailed {
280            message: "Missing 'number' field".to_string(),
281        })? as u32;
282
283    let title = json["title"]
284        .as_str()
285        .ok_or_else(|| PrError::GhJsonParseFailed {
286            message: "Missing 'title' field".to_string(),
287        })?
288        .to_string();
289
290    let author = json["author"]["login"]
291        .as_str()
292        .ok_or_else(|| PrError::GhJsonParseFailed {
293            message: "Missing 'author.login' field".to_string(),
294        })?
295        .to_string();
296
297    let head_ref = json["headRefName"]
298        .as_str()
299        .ok_or_else(|| PrError::GhJsonParseFailed {
300            message: "Missing 'headRefName' field".to_string(),
301        })?
302        .to_string();
303
304    let base_ref = json["baseRefName"]
305        .as_str()
306        .ok_or_else(|| PrError::GhJsonParseFailed {
307            message: "Missing 'baseRefName' field".to_string(),
308        })?
309        .to_string();
310
311    let is_fork = json["isCrossRepository"].as_bool().unwrap_or(false);
312
313    let (fork_owner, fork_url) = if is_fork {
314        let owner = json["headRepository"]["owner"]["login"]
315            .as_str()
316            .ok_or(PrError::MissingForkOwner)?
317            .to_string();
318        let url = json["headRepository"]["url"]
319            .as_str()
320            .map(|s| s.to_string());
321        (Some(owner), url)
322    } else {
323        (None, None)
324    };
325
326    Ok(PrMetadata {
327        number,
328        title,
329        author,
330        head_ref,
331        base_ref,
332        is_fork,
333        fork_owner,
334        fork_url,
335    })
336}
337
/// Turn arbitrary text into a lowercase, branch/directory-friendly slug.
///
/// The transformation works as follows:
/// - ASCII letters are lowercased; ASCII digits and `_` are kept.
/// - Every other character (spaces, `/`, punctuation, non-ASCII) becomes `-`.
/// - Runs of `-` are collapsed into one.
/// - Leading and trailing `-` and `_` are removed.
///
/// Example: `"Fix: Authentication Issue"` becomes `"fix-authentication-issue"`.
fn sanitize_for_branch_name(s: &str) -> String {
    let mut slug = String::with_capacity(s.len());

    for ch in s.chars() {
        let mapped = if ch.is_ascii_alphanumeric() || ch == '_' {
            ch.to_ascii_lowercase()
        } else {
            '-'
        };

        // Never emit two dashes in a row.
        if mapped == '-' && slug.ends_with('-') {
            continue;
        }
        slug.push(mapped);
    }

    slug.trim_matches(|c| c == '-' || c == '_').to_string()
}
367
368/// Expand all placeholders in `format` using `metadata`.
369///
370/// Supported placeholders: `{number}`, `{title}`, `{author}`, `{branch}`.
371/// Title, author, and branch values are sanitized for use in branch/directory names.
372pub fn format_pr_name_with_metadata(format: &str, metadata: &PrMetadata) -> String {
373    format
374        .replace("{number}", &metadata.number.to_string())
375        .replace("{title}", &sanitize_for_branch_name(&metadata.title))
376        .replace("{author}", &sanitize_for_branch_name(&metadata.author))
377        .replace("{branch}", &sanitize_for_branch_name(&metadata.head_ref))
378}
379
380/// Check if a string looks like a PR reference
381///
382/// This is a quick check used for routing decisions.
383pub fn is_pr_reference(input: &str) -> bool {
384    parse_pr_reference(input).ok().flatten().is_some()
385}
386
387/// Select which remote to use for fetching PR refs.
388///
389/// Priority: `upstream` → `origin` → first available remote.
390/// Returns [`PrError::NoRemoteConfigured`] if the repository has no remotes.
391pub fn detect_pr_remote(repo: &Repository) -> Result<String> {
392    let remotes = repo.remotes()?;
393
394    // Priority: upstream > origin
395    for name in &["upstream", "origin"] {
396        if remotes.iter().flatten().any(|r| r == *name) {
397            debug!("Using remote: {}", name);
398            return Ok(name.to_string());
399        }
400    }
401
402    // Fall back to first remote
403    if let Some(first_remote) = remotes.get(0) {
404        Ok(first_remote.to_string())
405    } else {
406        Err(PrError::NoRemoteConfigured.into())
407    }
408}
409
410/// Ensure a remote for a fork PR exists, then return its name.
411///
412/// For non-fork PRs this is equivalent to [`detect_pr_remote`].
413/// For fork PRs, a remote named `pr-{number}-fork` is added if it doesn't
414/// already exist, pointing at the fork's clone URL.
415pub fn setup_fork_remote(repo: &Repository, metadata: &PrMetadata) -> Result<String> {
416    if !metadata.is_fork {
417        // Not a fork - use regular remote
418        return detect_pr_remote(repo);
419    }
420
421    // Fork PR - need to add fork remote
422    let _fork_owner = metadata
423        .fork_owner
424        .as_ref()
425        .ok_or(PrError::MissingForkOwner)?;
426
427    let fork_url = metadata
428        .fork_url
429        .as_ref()
430        .ok_or(PrError::MissingForkOwner)?;
431
432    // Check if fork remote already exists
433    let fork_remote_name = format!("pr-{}-fork", metadata.number);
434
435    if repo.find_remote(&fork_remote_name).is_ok() {
436        debug!("Fork remote {} already exists", fork_remote_name);
437        return Ok(fork_remote_name);
438    }
439
440    // Add fork as remote
441    debug!("Adding fork remote: {} -> {}", fork_remote_name, fork_url);
442    repo.remote(&fork_remote_name, fork_url)
443        .map_err(|e| PrError::FetchFailed {
444            remote: fork_remote_name.clone(),
445            message: format!("Failed to add fork remote: {}", e),
446        })?;
447
448    Ok(fork_remote_name)
449}
450
451/// Fetch `branch` from `remote_name`, making it available as
452/// `refs/remotes/{remote_name}/{branch}`.
453///
454/// This is used for both fork and non-fork PRs to fetch the PR's head branch
455/// identified via `gh` CLI metadata. If the ref already exists locally the
456/// fetch is skipped.
457pub fn fetch_branch(repo: &Repository, remote_name: &str, branch: &str) -> Result<()> {
458    // Check if branch already exists locally
459    let branch_ref = format!("refs/remotes/{}/{}", remote_name, branch);
460    if repo.find_reference(&branch_ref).is_ok() {
461        debug!("Branch ref {} already exists", branch_ref);
462        return Ok(());
463    }
464
465    debug!("Fetching branch {} from remote {}", branch, remote_name);
466
467    let refspec = format!(
468        "+refs/heads/{}:refs/remotes/{}/{}",
469        branch, remote_name, branch
470    );
471
472    let mut fetch_options = FetchOptions::new();
473    fetch_options.remote_callbacks(get_remote_callbacks()?);
474
475    repo.find_remote(remote_name)?
476        .fetch(
477            &[refspec.as_str()],
478            Some(&mut fetch_options),
479            Some("Fetching PR branch"),
480        )
481        .map_err(|e| PrError::FetchFailed {
482            remote: remote_name.to_string(),
483            message: e.message().to_string(),
484        })?;
485
486    debug!("Successfully fetched branch {}", branch);
487    Ok(())
488}
489
/// Build a worktree name from `format` using only the PR number.
///
/// Every occurrence of `{number}` is replaced with `pr_number`; all other
/// text, including other placeholders, is left untouched.
pub fn format_pr_name(format: &str, pr_number: u32) -> String {
    let number = pr_number.to_string();
    format.replace("{number}", &number)
}
496
497/// Prepare everything needed to create a worktree for PR `pr_number`.
498///
499/// Orchestrates the complete PR workflow:
500/// 1. Checks that `gh` CLI is available
501/// 2. Fetches PR metadata via `gh`
502/// 3. Sets up a fork remote if the PR is cross-repository
503/// 4. Fetches the PR's head branch
504/// 5. Formats the worktree name using `pr_format`
505///
506/// Returns `(worktree_name, remote_ref, base_branch)` ready for `add_worktree`.
507pub fn prepare_pr_worktree(
508    repo: &Repository,
509    pr_number: u32,
510    pr_format: &str,
511) -> Result<(String, String, String)> {
512    debug!("Preparing PR worktree for PR #{}", pr_number);
513
514    // Fetch PR metadata from gh CLI
515    let metadata = fetch_pr_metadata(pr_number)?;
516    debug!(
517        "Fetched metadata: title='{}', author='{}', is_fork={}",
518        metadata.title, metadata.author, metadata.is_fork
519    );
520
521    // Setup remote and fetch branch
522    // For fork PRs: setup fork remote and fetch from it
523    // For non-fork PRs: use existing remote (origin/upstream)
524    let remote_name = if metadata.is_fork {
525        setup_fork_remote(repo, &metadata)?
526    } else {
527        detect_pr_remote(repo)?
528    };
529
530    // Fetch the actual branch from gh CLI metadata (works for both fork and non-fork)
531    fetch_branch(repo, &remote_name, &metadata.head_ref)?;
532
533    // Format worktree name using metadata
534    let worktree_name = format_pr_name_with_metadata(pr_format, &metadata);
535    debug!("Worktree name: {}", worktree_name);
536
537    // Build remote ref using the actual branch from metadata
538    let remote_ref = format!("{}/{}", remote_name, metadata.head_ref);
539    debug!("Remote ref: {}", remote_ref);
540
541    Ok((worktree_name, remote_ref, metadata.base_ref))
542}
543
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `input`, insisting that it is a well-formed PR reference.
    fn parse_ok(input: &str) -> PullRequest {
        parse_pr_reference(input)
            .expect("PR reference should parse without error")
            .expect("input should be recognised as a PR reference")
    }

    /// Non-fork PR metadata shared by the formatting tests.
    fn sample_metadata() -> PrMetadata {
        PrMetadata {
            number: 123,
            title: "Fix Authentication Bug".to_string(),
            author: "john-smith".to_string(),
            head_ref: "feature/fix-auth".to_string(),
            base_ref: "main".to_string(),
            is_fork: false,
            fork_owner: None,
            fork_url: None,
        }
    }

    #[test]
    fn test_parse_hash_number() {
        assert_eq!(
            parse_ok("#123"),
            PullRequest {
                number: 123,
                remote: None
            }
        );
    }

    #[test]
    fn test_parse_pr_hash_number() {
        assert_eq!(
            parse_ok("pr#456"),
            PullRequest {
                number: 456,
                remote: None
            }
        );
    }

    #[test]
    fn test_parse_pr_dash_number() {
        assert_eq!(
            parse_ok("pr-789"),
            PullRequest {
                number: 789,
                remote: None
            }
        );
    }

    #[test]
    fn test_parse_github_url() {
        assert_eq!(
            parse_ok("https://github.com/owner/repo/pull/999"),
            PullRequest {
                number: 999,
                remote: None
            }
        );
    }

    #[test]
    fn test_parse_remote_ref() {
        assert_eq!(
            parse_ok("origin/pull/111/head"),
            PullRequest {
                number: 111,
                remote: None
            }
        );
    }

    #[test]
    fn test_parse_regular_branch_name() {
        assert_eq!(parse_pr_reference("my-feature-branch").unwrap(), None);
    }

    #[test]
    fn test_parse_invalid_number() {
        assert!(parse_pr_reference("#abc").is_err());
    }

    #[test]
    fn test_is_pr_reference_true() {
        for input in [
            "#123",
            "pr#456",
            "pr-789",
            "https://github.com/owner/repo/pull/999",
        ] {
            assert!(is_pr_reference(input), "{:?} should be a PR reference", input);
        }
    }

    #[test]
    fn test_is_pr_reference_false() {
        for input in ["my-branch", "feature"] {
            assert!(!is_pr_reference(input), "{:?} should not be a PR reference", input);
        }
    }

    #[test]
    fn test_format_pr_name() {
        let cases = [
            ("pr-{number}", 123, "pr-123"),
            ("review-{number}", 456, "review-456"),
            ("{number}-test", 789, "789-test"),
        ];
        for (format, number, expected) in cases {
            assert_eq!(format_pr_name(format, number), expected);
        }
    }

    #[test]
    fn test_sanitize_branch_name() {
        let cases = [
            ("Fix Bug #123", "fix-bug-123"),
            ("Add Feature (v2)", "add-feature-v2"),
            ("john-smith", "john-smith"),
            ("Fix: Authentication Issue", "fix-authentication-issue"),
            ("Test@#$%", "test"),
        ];
        for (raw, expected) in cases {
            assert_eq!(sanitize_for_branch_name(raw), expected);
        }
    }

    #[test]
    fn test_format_with_metadata() {
        let metadata = sample_metadata();
        let cases = [
            ("pr-{number}", "pr-123"),
            ("{number}-{title}", "123-fix-authentication-bug"),
            ("{author}/pr-{number}", "john-smith/pr-123"),
            ("{branch}-{number}", "feature-fix-auth-123"),
        ];
        for (format, expected) in cases {
            assert_eq!(format_pr_name_with_metadata(format, &metadata), expected);
        }
    }

    // Integration tests that need a real, authenticated gh CLI (ignored by default).
    #[test]
    #[ignore]
    fn test_gh_cli_available() {
        check_gh_available().expect("gh CLI should be installed");
    }

    #[test]
    #[ignore]
    fn test_fetch_real_pr_metadata() {
        // Requires gh CLI and auth, run from a repository that has PR #1.
        // Adjust the PR number for the repository you test against.
        let metadata = fetch_pr_metadata(1).expect("Failed to fetch PR metadata");
        assert_eq!(metadata.number, 1);
        assert!(!metadata.title.is_empty());
        assert!(!metadata.author.is_empty());
    }
}
682}