Skip to main content

workon/
pr.rs

1//! Pull request support for creating worktrees from PR references.
2//!
3//! This module enables creating worktrees directly from pull request references,
4//! making it easy to review PRs in isolated worktrees.
5//!
6//! ## PR Reference Parsing
7//!
8//! Supports multiple PR reference formats:
9//! - `#123` - GitHub shorthand (most common)
10//! - `pr#123` or `pr-123` - Explicit PR references
11//! - `https://github.com/owner/repo/pull/123` - Full GitHub PR URL
12//! - `origin/pull/123/head` - Direct remote ref (less common)
13//!
14//! Parsing is lenient - if it looks like a PR reference, we'll try to extract the number.
15//!
16//! ## Smart Routing
17//!
18//! The CLI's smart routing (in main.rs) automatically detects PR references:
19//! ```bash
//! # Quote `#123`: in bash/sh an unquoted leading `#` starts a comment
//! git workon '#123'      # Routes to `new` command with PR reference
21//! git workon pr#123      # Same - creates PR worktree
22//! git workon feature     # Routes to `find` command (not a PR)
23//! ```
24//!
25//! ## Remote Detection Algorithm
26//!
27//! To fetch PRs, we need to determine which remote to use. The detection strategy:
28//! 1. Check for `upstream` remote (common in fork workflows)
29//! 2. Fall back to `origin` remote (most common)
30//! 3. Use first available remote (rare, but handles edge cases)
31//!
32//! This handles both direct repository workflows and fork-based workflows.
33//!
34//! ## Auto-Fetch Strategy
35//!
36//! PR branches are fetched automatically using gh CLI metadata:
37//! ```text
38//! git fetch <remote> +refs/heads/{branch}:refs/remotes/<remote>/{branch}
39//! ```
40//!
41//! Where `{branch}` is the actual branch name from the PR (obtained via gh CLI).
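//! The `+` allows non-fast-forward updates, so a force-pushed PR branch still updates cleanly.
//! The fetch only runs when `refs/remotes/<remote>/{branch}` does not already exist locally;
//! an existing remote-tracking ref is reused as-is and may lag behind the PR's current head.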
43//!
44//! For fork PRs, a fork remote is automatically added and the branch is fetched from it.
45//! For non-fork PRs, the branch is fetched from the detected remote (origin/upstream).
46//!
47//! ## Worktree Naming
48//!
49//! Worktree names are generated from `workon.prFormat` config (default: `pr-{number}`):
50//! - `pr-123` (default format)
51//! - `#123` (if configured with `#{number}`)
52//! - `pull-123` (if configured with `pull-{number}`)
53//!
//! The format must contain the `{number}` placeholder.
55//!
56//! ## Example Usage
57//!
58//! ```bash
59//! # Create worktree for PR #123 (auto-detects remote, auto-fetches)
//! git workon '#123'
61//!
62//! # Explicit PR reference
63//! git workon new pr#456
64//!
65//! # From GitHub URL
66//! git workon new https://github.com/user/repo/pull/789
67//!
68//! # Configure custom naming
69//! git config workon.prFormat "review-{number}"
//! git workon '#123'  # Creates worktree named "review-123"
71//! ```
72//!
73//! ## gh CLI Integration
74//!
75//! PR support integrates with gh CLI for rich metadata:
76//! - **Format placeholders**: {number}, {title}, {author}, {branch}
77//! - **Fork support**: Auto-adds fork remotes and fetches fork branches
78//! - **Metadata**: Fetches PR title, author, branch names, and state
79//! - **Validation**: Checks PR exists before creating worktree
80
81use git2::{FetchOptions, Repository};
82use log::debug;
83
84use crate::{
85    error::{PrError, Result},
86    get_remote_callbacks,
87};
88
/// A parsed pull request reference from user input.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PullRequest {
    /// The PR number extracted from the reference string.
    pub number: u32,
    /// Remote name associated with the reference.
    ///
    /// Meant for references that name a remote (e.g. `origin/pull/123/head`), but
    /// [`parse_pr_reference`] currently never fills it in: it is `None` for every
    /// supported format, including remote refs.
    pub remote: Option<String>,
}
97
/// Pull request details as reported by `gh pr view --json`.
#[derive(Clone, Debug)]
pub struct PrMetadata {
    /// The pull request number.
    pub number: u32,
    /// The pull request title, exactly as reported (not sanitized).
    pub title: String,
    /// Login of the user who opened the PR.
    pub author: String,
    /// Head branch: the branch holding the PR's changes.
    pub head_ref: String,
    /// Base branch: the branch the PR is meant to be merged into.
    pub base_ref: String,
    /// Whether the head branch lives in a different repository (a fork).
    pub is_fork: bool,
    /// Login of the fork's owner; only populated for fork PRs.
    pub fork_owner: Option<String>,
    /// URL used to add the fork as a git remote; only populated for fork PRs.
    pub fork_url: Option<String>,
}
118
119/// Parse a PR reference from user input
120///
121/// Supported formats:
122/// - `#123` - GitHub shorthand
123/// - `pr#123` or `pr-123` - Explicit PR references
124/// - `https://github.com/owner/repo/pull/123` - GitHub PR URL
125/// - `origin/pull/123/head` - Direct remote ref
126///
127/// Returns `Ok(None)` if the input is not a PR reference.
128/// Returns `Ok(Some(PullRequest))` if successfully parsed.
129/// Returns `Err` if the input looks like a PR reference but is malformed.
130pub fn parse_pr_reference(input: &str) -> Result<Option<PullRequest>> {
131    // Try #123 format
132    if let Some(num_str) = input.strip_prefix('#') {
133        return parse_number(num_str, input).map(|num| {
134            Some(PullRequest {
135                number: num,
136                remote: None,
137            })
138        });
139    }
140
141    // Try pr#123 format
142    if let Some(num_str) = input.strip_prefix("pr#") {
143        return parse_number(num_str, input).map(|num| {
144            Some(PullRequest {
145                number: num,
146                remote: None,
147            })
148        });
149    }
150
151    // Try pr-123 format
152    if let Some(num_str) = input.strip_prefix("pr-") {
153        return parse_number(num_str, input).map(|num| {
154            Some(PullRequest {
155                number: num,
156                remote: None,
157            })
158        });
159    }
160
161    // Try GitHub URL: https://github.com/owner/repo/pull/123
162    if input.contains("github.com") && input.contains("/pull/") {
163        return parse_github_url(input);
164    }
165
166    // Try remote ref format: origin/pull/123/head
167    if input.contains("/pull/") && input.ends_with("/head") {
168        return parse_remote_ref(input);
169    }
170
171    // Not a PR reference
172    Ok(None)
173}
174
175/// Helper to parse a number string
176fn parse_number(num_str: &str, original_input: &str) -> Result<u32> {
177    num_str.parse::<u32>().map_err(|_| {
178        PrError::InvalidReference {
179            input: original_input.to_string(),
180        }
181        .into()
182    })
183}
184
185/// Parse GitHub PR URL
186fn parse_github_url(url: &str) -> Result<Option<PullRequest>> {
187    // Extract the PR number from URL like: https://github.com/owner/repo/pull/123
188    let parts: Vec<&str> = url.split('/').collect();
189
190    // Find "pull" in the path and get the number after it
191    for (i, &part) in parts.iter().enumerate() {
192        if part == "pull" && i + 1 < parts.len() {
193            let num_str = parts[i + 1];
194            let number = parse_number(num_str, url)?;
195            return Ok(Some(PullRequest {
196                number,
197                remote: None,
198            }));
199        }
200    }
201
202    Err(PrError::InvalidReference {
203        input: url.to_string(),
204    }
205    .into())
206}
207
208/// Parse remote ref format: origin/pull/123/head
209fn parse_remote_ref(ref_str: &str) -> Result<Option<PullRequest>> {
210    // Format: remote/pull/number/head
211    let parts: Vec<&str> = ref_str.split('/').collect();
212
213    if parts.len() >= 4 && parts[parts.len() - 3] == "pull" && parts[parts.len() - 1] == "head" {
214        let num_str = parts[parts.len() - 2];
215        let number = parse_number(num_str, ref_str)?;
216        return Ok(Some(PullRequest {
217            number,
218            remote: None,
219        }));
220    }
221
222    Err(PrError::InvalidReference {
223        input: ref_str.to_string(),
224    }
225    .into())
226}
227
228/// Return `Ok(())` if the `gh` CLI is installed and reachable in `PATH`.
229///
230/// Returns [`PrError::GhNotInstalled`] if `gh` cannot be executed.
231pub fn check_gh_available() -> Result<()> {
232    std::process::Command::new("gh")
233        .arg("--version")
234        .output()
235        .map_err(|_| PrError::GhNotInstalled)?;
236    Ok(())
237}
238
239/// Fetch PR metadata for `pr_number` using the `gh` CLI.
240///
241/// Runs `gh pr view <pr_number> --json ...` and parses the JSON output.
242/// Requires `gh` to be authenticated (`gh auth login`).
243pub fn fetch_pr_metadata(pr_number: u32) -> Result<PrMetadata> {
244    // Ensure gh is available
245    check_gh_available()?;
246
247    // Fetch PR metadata with single gh command
248    let output = std::process::Command::new("gh")
249        .args([
250            "pr",
251            "view",
252            &pr_number.to_string(),
253            "--json",
254            "number,title,author,headRefName,baseRefName,isCrossRepository,headRepository",
255        ])
256        .output()
257        .map_err(|e| PrError::GhFetchFailed {
258            message: e.to_string(),
259        })?;
260
261    if !output.status.success() {
262        let stderr = String::from_utf8_lossy(&output.stderr);
263        return Err(PrError::GhFetchFailed {
264            message: stderr.to_string(),
265        }
266        .into());
267    }
268
269    // Parse JSON response
270    let json_str = String::from_utf8_lossy(&output.stdout);
271    let json: serde_json::Value =
272        serde_json::from_str(&json_str).map_err(|e| PrError::GhJsonParseFailed {
273            message: e.to_string(),
274        })?;
275
276    // Extract fields
277    let number = json["number"]
278        .as_u64()
279        .ok_or_else(|| PrError::GhJsonParseFailed {
280            message: "Missing 'number' field".to_string(),
281        })? as u32;
282
283    let title = json["title"]
284        .as_str()
285        .ok_or_else(|| PrError::GhJsonParseFailed {
286            message: "Missing 'title' field".to_string(),
287        })?
288        .to_string();
289
290    let author = json["author"]["login"]
291        .as_str()
292        .ok_or_else(|| PrError::GhJsonParseFailed {
293            message: "Missing 'author.login' field".to_string(),
294        })?
295        .to_string();
296
297    let head_ref = json["headRefName"]
298        .as_str()
299        .ok_or_else(|| PrError::GhJsonParseFailed {
300            message: "Missing 'headRefName' field".to_string(),
301        })?
302        .to_string();
303
304    let base_ref = json["baseRefName"]
305        .as_str()
306        .ok_or_else(|| PrError::GhJsonParseFailed {
307            message: "Missing 'baseRefName' field".to_string(),
308        })?
309        .to_string();
310
311    let is_fork = json["isCrossRepository"].as_bool().unwrap_or(false);
312
313    let (fork_owner, fork_url) = if is_fork {
314        let owner = json["headRepository"]["owner"]["login"]
315            .as_str()
316            .ok_or(PrError::MissingForkOwner)?
317            .to_string();
318        let url = json["headRepository"]["url"]
319            .as_str()
320            .map(|s| s.to_string());
321        (Some(owner), url)
322    } else {
323        (None, None)
324    };
325
326    Ok(PrMetadata {
327        number,
328        title,
329        author,
330        head_ref,
331        base_ref,
332        is_fork,
333        fork_owner,
334        fork_url,
335    })
336}
337
/// Turn arbitrary text (a PR title, author login, or branch name) into a slug that can
/// be embedded in a branch or worktree name.
///
/// - ASCII letters are lowercased; ASCII digits and `_` are kept.
/// - Every other character becomes `-`. This includes `-` itself, spaces, `/`, and any
///   non-ASCII character.
/// - Runs of `-` collapse to a single `-`.
/// - Leading and trailing `-`/`_` are trimmed.
///
/// The result may be empty.
fn sanitize_for_branch_name(s: &str) -> String {
    let mut slug = String::with_capacity(s.len());

    for ch in s.chars() {
        let mapped = if ch.is_ascii_alphanumeric() || ch == '_' {
            ch.to_ascii_lowercase()
        } else {
            '-'
        };
        // A dash directly after another dash would start a run; drop it.
        if mapped == '-' && slug.ends_with('-') {
            continue;
        }
        slug.push(mapped);
    }

    slug.trim_matches(|c: char| c == '-' || c == '_').to_string()
}
367
368/// Expand all placeholders in `format` using `metadata`.
369///
370/// Supported placeholders: `{number}`, `{title}`, `{author}`, `{branch}`.
371/// Title, author, and branch values are sanitized for use in branch/directory names.
372pub fn format_pr_name_with_metadata(format: &str, metadata: &PrMetadata) -> String {
373    format
374        .replace("{number}", &metadata.number.to_string())
375        .replace("{title}", &sanitize_for_branch_name(&metadata.title))
376        .replace("{author}", &sanitize_for_branch_name(&metadata.author))
377        .replace("{branch}", &sanitize_for_branch_name(&metadata.head_ref))
378}
379
380/// Check if a string looks like a PR reference
381///
382/// This is a quick check used for routing decisions.
383pub fn is_pr_reference(input: &str) -> bool {
384    parse_pr_reference(input).ok().flatten().is_some()
385}
386
387/// Select which remote to use for fetching PR refs.
388///
389/// Priority: `upstream` → `origin` → first available remote.
390/// Returns [`PrError::NoRemoteConfigured`] if the repository has no remotes.
391pub fn detect_pr_remote(repo: &Repository) -> Result<String> {
392    let remotes = repo.remotes()?;
393
394    // Priority: upstream > origin
395    for name in &["upstream", "origin"] {
396        if remotes.iter().flatten().flatten().any(|r| r == *name) {
397            debug!("Using remote: {}", name);
398            return Ok(name.to_string());
399        }
400    }
401
402    // Fall back to first remote
403    if let Ok(Some(first_remote)) = remotes.get(0) {
404        Ok(first_remote.to_string())
405    } else {
406        Err(PrError::NoRemoteConfigured.into())
407    }
408}
409
410/// Ensure a remote for a fork PR exists, then return its name.
411///
412/// For non-fork PRs this is equivalent to [`detect_pr_remote`].
413/// For fork PRs, a remote named `pr-{number}-fork` is added if it doesn't
414/// already exist, pointing at the fork's clone URL.
415pub fn setup_fork_remote(repo: &Repository, metadata: &PrMetadata) -> Result<String> {
416    if !metadata.is_fork {
417        // Not a fork - use regular remote
418        return detect_pr_remote(repo);
419    }
420
421    // Fork PR - need to add fork remote
422    let _fork_owner = metadata
423        .fork_owner
424        .as_ref()
425        .ok_or(PrError::MissingForkOwner)?;
426
427    let fork_url = metadata
428        .fork_url
429        .as_ref()
430        .ok_or(PrError::MissingForkOwner)?;
431
432    // Check if fork remote already exists
433    let fork_remote_name = format!("pr-{}-fork", metadata.number);
434
435    if repo.find_remote(&fork_remote_name).is_ok() {
436        debug!("Fork remote {} already exists", fork_remote_name);
437        return Ok(fork_remote_name);
438    }
439
440    // Add fork as remote
441    debug!("Adding fork remote: {} -> {}", fork_remote_name, fork_url);
442    repo.remote(&fork_remote_name, fork_url)
443        .map_err(|e| PrError::FetchFailed {
444            remote: fork_remote_name.clone(),
445            message: format!("Failed to add fork remote: {}", e),
446        })?;
447
448    Ok(fork_remote_name)
449}
450
451/// Fetch `branch` from `remote_name`, making it available as
452/// `refs/remotes/{remote_name}/{branch}`.
453///
454/// This is used for both fork and non-fork PRs to fetch the PR's head branch
455/// identified via `gh` CLI metadata. If the ref already exists locally the
456/// fetch is skipped.
457pub fn fetch_branch(repo: &Repository, remote_name: &str, branch: &str) -> Result<()> {
458    // Check if branch already exists locally
459    let branch_ref = format!("refs/remotes/{}/{}", remote_name, branch);
460    if repo.find_reference(&branch_ref).is_ok() {
461        debug!("Branch ref {} already exists", branch_ref);
462        return Ok(());
463    }
464
465    debug!("Fetching branch {} from remote {}", branch, remote_name);
466
467    let refspec = format!(
468        "+refs/heads/{}:refs/remotes/{}/{}",
469        branch, remote_name, branch
470    );
471
472    let remote_url = repo
473        .find_remote(remote_name)
474        .ok()
475        .and_then(|r| r.url().ok().map(str::to_string));
476    let auth = get_remote_callbacks(repo, remote_url.as_deref())?;
477    let mut fetch_options = FetchOptions::new();
478    fetch_options.remote_callbacks(auth.callbacks());
479
480    repo.find_remote(remote_name)?
481        .fetch(
482            &[refspec.as_str()],
483            Some(&mut fetch_options),
484            Some("Fetching PR branch"),
485        )
486        .map_err(|e| PrError::FetchFailed {
487            remote: remote_name.to_string(),
488            message: e.message().to_string(),
489        })?;
490
491    debug!("Successfully fetched branch {}", branch);
492    Ok(())
493}
494
/// Build a PR worktree name by substituting `pr_number` for every `{number}` in `format`.
///
/// Only `{number}` is expanded. Other placeholder text such as `{title}` is left
/// untouched; use [`format_pr_name_with_metadata`] when PR metadata is available.
pub fn format_pr_name(format: &str, pr_number: u32) -> String {
    let number = pr_number.to_string();
    format.replace("{number}", &number)
}
501
502/// Prepare everything needed to create a worktree for PR `pr_number`.
503///
504/// Orchestrates the complete PR workflow:
505/// 1. Checks that `gh` CLI is available
506/// 2. Fetches PR metadata via `gh`
507/// 3. Sets up a fork remote if the PR is cross-repository
508/// 4. Fetches the PR's head branch
509/// 5. Formats the worktree name using `pr_format`
510///
511/// Returns `(worktree_name, remote_ref, base_branch)` ready for `add_worktree`.
512pub fn prepare_pr_worktree(
513    repo: &Repository,
514    pr_number: u32,
515    pr_format: &str,
516) -> Result<(String, String, String)> {
517    debug!("Preparing PR worktree for PR #{}", pr_number);
518
519    // Fetch PR metadata from gh CLI
520    let metadata = fetch_pr_metadata(pr_number)?;
521    debug!(
522        "Fetched metadata: title='{}', author='{}', is_fork={}",
523        metadata.title, metadata.author, metadata.is_fork
524    );
525
526    // Setup remote and fetch branch
527    // For fork PRs: setup fork remote and fetch from it
528    // For non-fork PRs: use existing remote (origin/upstream)
529    let remote_name = if metadata.is_fork {
530        setup_fork_remote(repo, &metadata)?
531    } else {
532        detect_pr_remote(repo)?
533    };
534
535    // Fetch the actual branch from gh CLI metadata (works for both fork and non-fork)
536    fetch_branch(repo, &remote_name, &metadata.head_ref)?;
537
538    // Format worktree name using metadata
539    let worktree_name = format_pr_name_with_metadata(pr_format, &metadata);
540    debug!("Worktree name: {}", worktree_name);
541
542    // Build remote ref using the actual branch from metadata
543    let remote_ref = format!("{}/{}", remote_name, metadata.head_ref);
544    debug!("Remote ref: {}", remote_ref);
545
546    Ok((worktree_name, remote_ref, metadata.base_ref))
547}
548
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_hash_number() {
        let pr = parse_pr_reference("#123").unwrap().unwrap();
        assert_eq!(pr.number, 123);
        assert_eq!(pr.remote, None);
    }

    #[test]
    fn test_parse_pr_hash_number() {
        let pr = parse_pr_reference("pr#456").unwrap().unwrap();
        assert_eq!(pr.number, 456);
        assert_eq!(pr.remote, None);
    }

    #[test]
    fn test_parse_pr_dash_number() {
        let pr = parse_pr_reference("pr-789").unwrap().unwrap();
        assert_eq!(pr.number, 789);
        assert_eq!(pr.remote, None);
    }

    #[test]
    fn test_parse_github_url() {
        let pr = parse_pr_reference("https://github.com/owner/repo/pull/999")
            .unwrap()
            .unwrap();
        assert_eq!(pr.number, 999);
        assert_eq!(pr.remote, None);
    }

    #[test]
    fn test_parse_github_url_with_trailing_path() {
        // Links to a PR's "Files changed" tab keep the number right after `pull`.
        let pr = parse_pr_reference("https://github.com/owner/repo/pull/42/files")
            .unwrap()
            .unwrap();
        assert_eq!(pr.number, 42);
    }

    #[test]
    fn test_parse_github_url_without_number_is_error() {
        assert!(parse_pr_reference("https://github.com/owner/repo/pull/").is_err());
    }

    #[test]
    fn test_parse_remote_ref() {
        let pr = parse_pr_reference("origin/pull/111/head").unwrap().unwrap();
        assert_eq!(pr.number, 111);
        // The remote prefix is currently not captured.
        assert_eq!(pr.remote, None);
    }

    #[test]
    fn test_parse_regular_branch_name() {
        let result = parse_pr_reference("my-feature-branch").unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_parse_invalid_number() {
        let result = parse_pr_reference("#abc");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_prefixed_non_numeric_is_error() {
        // Inputs that match a PR form but carry no valid number are errors, not `Ok(None)`...
        assert!(parse_pr_reference("pr-fix-login").is_err());
        assert!(parse_pr_reference("origin/pull/abc/head").is_err());
        // ...while the routing check simply reports "not a PR reference".
        assert!(!is_pr_reference("pr-fix-login"));
        assert!(!is_pr_reference("#abc"));
    }

    #[test]
    fn test_is_pr_reference_true() {
        assert!(is_pr_reference("#123"));
        assert!(is_pr_reference("pr#456"));
        assert!(is_pr_reference("pr-789"));
        assert!(is_pr_reference("https://github.com/owner/repo/pull/999"));
    }

    #[test]
    fn test_is_pr_reference_false() {
        assert!(!is_pr_reference("my-branch"));
        assert!(!is_pr_reference("feature"));
    }

    #[test]
    fn test_format_pr_name() {
        assert_eq!(format_pr_name("pr-{number}", 123), "pr-123");
        assert_eq!(format_pr_name("review-{number}", 456), "review-456");
        assert_eq!(format_pr_name("{number}-test", 789), "789-test");
    }

    #[test]
    fn test_sanitize_branch_name() {
        assert_eq!(sanitize_for_branch_name("Fix Bug #123"), "fix-bug-123");
        assert_eq!(
            sanitize_for_branch_name("Add Feature (v2)"),
            "add-feature-v2"
        );
        assert_eq!(sanitize_for_branch_name("john-smith"), "john-smith");
        assert_eq!(
            sanitize_for_branch_name("Fix: Authentication Issue"),
            "fix-authentication-issue"
        );
        assert_eq!(sanitize_for_branch_name("Test@#$%"), "test");
    }

    #[test]
    fn test_sanitize_branch_name_edge_cases() {
        assert_eq!(sanitize_for_branch_name(""), "");
        // Underscores are kept inside, trimmed at the ends; dash runs collapse.
        assert_eq!(sanitize_for_branch_name("--a__b--"), "a__b");
        // Non-ASCII characters are replaced, not transliterated.
        assert_eq!(sanitize_for_branch_name("Café"), "caf");
    }

    #[test]
    fn test_format_with_metadata() {
        let metadata = PrMetadata {
            number: 123,
            title: "Fix Authentication Bug".to_string(),
            author: "john-smith".to_string(),
            head_ref: "feature/fix-auth".to_string(),
            base_ref: "main".to_string(),
            is_fork: false,
            fork_owner: None,
            fork_url: None,
        };

        assert_eq!(
            format_pr_name_with_metadata("pr-{number}", &metadata),
            "pr-123"
        );
        assert_eq!(
            format_pr_name_with_metadata("{number}-{title}", &metadata),
            "123-fix-authentication-bug"
        );
        assert_eq!(
            format_pr_name_with_metadata("{author}/pr-{number}", &metadata),
            "john-smith/pr-123"
        );
        assert_eq!(
            format_pr_name_with_metadata("{branch}-{number}", &metadata),
            "feature-fix-auth-123"
        );
    }

    // Integration tests requiring gh CLI (marked with #[ignore])
    #[test]
    #[ignore]
    fn test_gh_cli_available() {
        check_gh_available().expect("gh CLI should be installed");
    }

    #[test]
    #[ignore]
    fn test_fetch_real_pr_metadata() {
        // Requires gh CLI and auth
        // This test uses a real PR from a public repo (git-workon itself if available)
        // Replace with actual PR number from your repository for testing
        let metadata = fetch_pr_metadata(1).expect("Failed to fetch PR metadata");
        assert_eq!(metadata.number, 1);
        assert!(!metadata.title.is_empty());
        assert!(!metadata.author.is_empty());
    }
}