Skip to main content

workon/
pr.rs

1//! Pull request support for creating worktrees from PR references.
2//!
3//! This module enables creating worktrees directly from pull request references,
4//! making it easy to review PRs in isolated worktrees.
5//!
6//! ## PR Reference Parsing
7//!
8//! Supports multiple PR reference formats:
9//! - `#123` - GitHub shorthand (most common)
10//! - `pr#123` or `pr-123` - Explicit PR references
11//! - `https://github.com/owner/repo/pull/123` - Full GitHub PR URL
12//! - `origin/pull/123/head` - Direct remote ref (less common)
13//!
14//! Parsing is lenient - if it looks like a PR reference, we'll try to extract the number.
15//!
16//! ## Smart Routing
17//!
18//! The CLI's smart routing (in main.rs) automatically detects PR references:
19//! ```bash
20//! git workon #123        # Routes to `new` command with PR reference
21//! git workon pr#123      # Same - creates PR worktree
22//! git workon feature     # Routes to `find` command (not a PR)
23//! ```
24//!
25//! ## Remote Detection Algorithm
26//!
27//! To fetch PRs, we need to determine which remote to use. The detection strategy:
28//! 1. Check for `upstream` remote (common in fork workflows)
29//! 2. Fall back to `origin` remote (most common)
30//! 3. Use first available remote (rare, but handles edge cases)
31//!
32//! This handles both direct repository workflows and fork-based workflows.
33//!
34//! ## Auto-Fetch Strategy
35//!
36//! PR branches are fetched automatically using gh CLI metadata:
37//! ```text
38//! git fetch <remote> +refs/heads/{branch}:refs/remotes/<remote>/{branch}
39//! ```
40//!
41//! Where `{branch}` is the actual branch name from the PR (obtained via gh CLI).
42//! The `+` forces the fetch even if not fast-forward, ensuring we always get the latest PR state.
43//!
44//! For fork PRs, a fork remote is automatically added and the branch is fetched from it.
45//! For non-fork PRs, the branch is fetched from the detected remote (origin/upstream).
46//!
47//! ## Worktree Naming
48//!
49//! Worktree names are generated from `workon.prFormat` config (default: `pr-{number}`):
50//! - `pr-123` (default format)
51//! - `#123` (if configured with `#{number}`)
52//! - `pull-123` (if configured with `pull-{number}`)
53//!
54//! The format must contain `{number}` placeholder.
55//!
56//! ## Example Usage
57//!
58//! ```bash
59//! # Create worktree for PR #123 (auto-detects remote, auto-fetches)
60//! git workon #123
61//!
62//! # Explicit PR reference
63//! git workon new pr#456
64//!
65//! # From GitHub URL
66//! git workon new https://github.com/user/repo/pull/789
67//!
68//! # Configure custom naming
69//! git config workon.prFormat "review-{number}"
70//! git workon #123  # Creates worktree named "review-123"
71//! ```
72//!
73//! ## gh CLI Integration
74//!
75//! PR support integrates with gh CLI for rich metadata:
76//! - **Format placeholders**: {number}, {title}, {author}, {branch}
77//! - **Fork support**: Auto-adds fork remotes and fetches fork branches
78//! - **Metadata**: Fetches PR title, author, branch names, and state
79//! - **Validation**: Checks PR exists before creating worktree
80
81use git2::{FetchOptions, Repository};
82use log::debug;
83
84use crate::{
85    error::{PrError, Result},
86    get_remote_callbacks,
87};
88
/// A pull request reference parsed from user input by [`parse_pr_reference`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PullRequest {
    /// The PR number extracted from the reference string.
    pub number: u32,
    /// Remote name associated with the reference, if any.
    ///
    /// NOTE(review): every parser in this module currently leaves this `None`,
    /// including for remote-ref input such as `origin/pull/123/head`; callers
    /// should not rely on it being populated.
    pub remote: Option<String>,
}
97
/// PR metadata fetched from the `gh` CLI.
///
/// Derives `PartialEq`/`Eq` (like [`PullRequest`]) so metadata values can be
/// compared directly, e.g. in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PrMetadata {
    /// PR number.
    pub number: u32,
    /// PR title.
    pub title: String,
    /// GitHub login of the PR author.
    pub author: String,
    /// Name of the branch that the PR was created from.
    pub head_ref: String,
    /// Name of the branch the PR targets.
    pub base_ref: String,
    /// True if the PR comes from a forked repository.
    pub is_fork: bool,
    /// GitHub login of the fork owner; populated by `fetch_pr_metadata` only
    /// when `is_fork` is true.
    pub fork_owner: Option<String>,
    /// Clone URL of the fork repository; may be `None` even for fork PRs if
    /// `gh` did not report enough information to build one.
    pub fork_url: Option<String>,
}
118
119/// Parse a PR reference from user input
120///
121/// Supported formats:
122/// - `#123` - GitHub shorthand
123/// - `pr#123` or `pr-123` - Explicit PR references
124/// - `https://github.com/owner/repo/pull/123` - GitHub PR URL
125/// - `origin/pull/123/head` - Direct remote ref
126///
127/// Returns `Ok(None)` if the input is not a PR reference.
128/// Returns `Ok(Some(PullRequest))` if successfully parsed.
129/// Returns `Err` if the input looks like a PR reference but is malformed.
130pub fn parse_pr_reference(input: &str) -> Result<Option<PullRequest>> {
131    // Try #123 format
132    if let Some(num_str) = input.strip_prefix('#') {
133        return parse_number(num_str, input).map(|num| {
134            Some(PullRequest {
135                number: num,
136                remote: None,
137            })
138        });
139    }
140
141    // Try pr#123 format
142    if let Some(num_str) = input.strip_prefix("pr#") {
143        return parse_number(num_str, input).map(|num| {
144            Some(PullRequest {
145                number: num,
146                remote: None,
147            })
148        });
149    }
150
151    // Try pr-123 format
152    if let Some(num_str) = input.strip_prefix("pr-") {
153        return parse_number(num_str, input).map(|num| {
154            Some(PullRequest {
155                number: num,
156                remote: None,
157            })
158        });
159    }
160
161    // Try GitHub URL: https://github.com/owner/repo/pull/123
162    if input.contains("github.com") && input.contains("/pull/") {
163        return parse_github_url(input);
164    }
165
166    // Try remote ref format: origin/pull/123/head
167    if input.contains("/pull/") && input.ends_with("/head") {
168        return parse_remote_ref(input);
169    }
170
171    // Not a PR reference
172    Ok(None)
173}
174
175/// Helper to parse a number string
176fn parse_number(num_str: &str, original_input: &str) -> Result<u32> {
177    num_str.parse::<u32>().map_err(|_| {
178        PrError::InvalidReference {
179            input: original_input.to_string(),
180        }
181        .into()
182    })
183}
184
185/// Parse GitHub PR URL
186fn parse_github_url(url: &str) -> Result<Option<PullRequest>> {
187    // Extract the PR number from URL like: https://github.com/owner/repo/pull/123
188    let parts: Vec<&str> = url.split('/').collect();
189
190    // Find "pull" in the path and get the number after it
191    for (i, &part) in parts.iter().enumerate() {
192        if part == "pull" && i + 1 < parts.len() {
193            let num_str = parts[i + 1];
194            let number = parse_number(num_str, url)?;
195            return Ok(Some(PullRequest {
196                number,
197                remote: None,
198            }));
199        }
200    }
201
202    Err(PrError::InvalidReference {
203        input: url.to_string(),
204    }
205    .into())
206}
207
208/// Parse remote ref format: origin/pull/123/head
209fn parse_remote_ref(ref_str: &str) -> Result<Option<PullRequest>> {
210    // Format: remote/pull/number/head
211    let parts: Vec<&str> = ref_str.split('/').collect();
212
213    if parts.len() >= 4 && parts[parts.len() - 3] == "pull" && parts[parts.len() - 1] == "head" {
214        let num_str = parts[parts.len() - 2];
215        let number = parse_number(num_str, ref_str)?;
216        return Ok(Some(PullRequest {
217            number,
218            remote: None,
219        }));
220    }
221
222    Err(PrError::InvalidReference {
223        input: ref_str.to_string(),
224    }
225    .into())
226}
227
228/// Return `Ok(())` if the `gh` CLI is installed and reachable in `PATH`.
229///
230/// Returns [`PrError::GhNotInstalled`] if `gh` cannot be executed.
231pub fn check_gh_available() -> Result<()> {
232    std::process::Command::new("gh")
233        .arg("--version")
234        .output()
235        .map_err(|_| PrError::GhNotInstalled)?;
236    Ok(())
237}
238
239/// Fetch PR metadata for `pr_number` using the `gh` CLI.
240///
241/// Runs `gh pr view <pr_number> --json ...` and parses the JSON output.
242/// Requires `gh` to be authenticated (`gh auth login`).
243pub fn fetch_pr_metadata(pr_number: u32) -> Result<PrMetadata> {
244    // Ensure gh is available
245    check_gh_available()?;
246
247    // Fetch PR metadata with single gh command
248    let output = std::process::Command::new("gh")
249        .args([
250            "pr",
251            "view",
252            &pr_number.to_string(),
253            "--json",
254            "number,title,author,headRefName,baseRefName,isCrossRepository,headRepository",
255        ])
256        .output()
257        .map_err(|e| PrError::GhFetchFailed {
258            message: e.to_string(),
259        })?;
260
261    if !output.status.success() {
262        let stderr = String::from_utf8_lossy(&output.stderr);
263        return Err(PrError::GhFetchFailed {
264            message: stderr.to_string(),
265        }
266        .into());
267    }
268
269    // Parse JSON response
270    let json_str = String::from_utf8_lossy(&output.stdout);
271    let json: serde_json::Value =
272        serde_json::from_str(&json_str).map_err(|e| PrError::GhJsonParseFailed {
273            message: e.to_string(),
274        })?;
275
276    // Extract fields
277    let number = json["number"]
278        .as_u64()
279        .ok_or_else(|| PrError::GhJsonParseFailed {
280            message: "Missing 'number' field".to_string(),
281        })? as u32;
282
283    let title = json["title"]
284        .as_str()
285        .ok_or_else(|| PrError::GhJsonParseFailed {
286            message: "Missing 'title' field".to_string(),
287        })?
288        .to_string();
289
290    let author = json["author"]["login"]
291        .as_str()
292        .ok_or_else(|| PrError::GhJsonParseFailed {
293            message: "Missing 'author.login' field".to_string(),
294        })?
295        .to_string();
296
297    let head_ref = json["headRefName"]
298        .as_str()
299        .ok_or_else(|| PrError::GhJsonParseFailed {
300            message: "Missing 'headRefName' field".to_string(),
301        })?
302        .to_string();
303
304    let base_ref = json["baseRefName"]
305        .as_str()
306        .ok_or_else(|| PrError::GhJsonParseFailed {
307            message: "Missing 'baseRefName' field".to_string(),
308        })?
309        .to_string();
310
311    let is_fork = json["isCrossRepository"].as_bool().unwrap_or(false);
312
313    let (fork_owner, fork_url) = if is_fork {
314        let owner = json["headRepository"]["owner"]["login"]
315            .as_str()
316            .ok_or(PrError::MissingForkOwner)?
317            .to_string();
318        let url = json["headRepository"]["url"]
319            .as_str()
320            .map(|s| s.to_string());
321        (Some(owner), url)
322    } else {
323        (None, None)
324    };
325
326    Ok(PrMetadata {
327        number,
328        title,
329        author,
330        head_ref,
331        base_ref,
332        is_fork,
333        fork_owner,
334        fork_url,
335    })
336}
337
/// Sanitize a string for use in branch/worktree names.
///
/// - ASCII letters are lowercased.
/// - ASCII digits, `-` and `_` are kept.
/// - Every other character (spaces, `/`, punctuation, non-ASCII) becomes `-`.
/// - Runs of `-` are collapsed to a single dash.
/// - Leading and trailing `-` and `_` are stripped.
fn sanitize_for_branch_name(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for ch in s.chars() {
        let mapped = if ch.is_ascii_alphanumeric() || ch == '_' || ch == '-' {
            ch.to_ascii_lowercase()
        } else {
            '-'
        };
        // Drop a dash if the output already ends with one (collapses runs).
        if mapped == '-' && out.ends_with('-') {
            continue;
        }
        out.push(mapped);
    }
    out.trim_matches(|c: char| c == '-' || c == '_').to_string()
}
367
368/// Expand all placeholders in `format` using `metadata`.
369///
370/// Supported placeholders: `{number}`, `{title}`, `{author}`, `{branch}`.
371/// Title, author, and branch values are sanitized for use in branch/directory names.
372pub fn format_pr_name_with_metadata(format: &str, metadata: &PrMetadata) -> String {
373    format
374        .replace("{number}", &metadata.number.to_string())
375        .replace("{title}", &sanitize_for_branch_name(&metadata.title))
376        .replace("{author}", &sanitize_for_branch_name(&metadata.author))
377        .replace("{branch}", &sanitize_for_branch_name(&metadata.head_ref))
378}
379
380/// Check if a string looks like a PR reference
381///
382/// This is a quick check used for routing decisions.
383pub fn is_pr_reference(input: &str) -> bool {
384    parse_pr_reference(input).ok().flatten().is_some()
385}
386
387/// Select which remote to use for fetching PR refs.
388///
389/// Priority: `upstream` → `origin` → first available remote.
390/// Returns [`PrError::NoRemoteConfigured`] if the repository has no remotes.
391pub fn detect_pr_remote(repo: &Repository) -> Result<String> {
392    let remotes = repo.remotes()?;
393
394    // Priority: upstream > origin
395    for name in &["upstream", "origin"] {
396        if remotes.iter().flatten().any(|r| r == *name) {
397            debug!("Using remote: {}", name);
398            return Ok(name.to_string());
399        }
400    }
401
402    // Fall back to first remote
403    if let Some(first_remote) = remotes.get(0) {
404        Ok(first_remote.to_string())
405    } else {
406        Err(PrError::NoRemoteConfigured.into())
407    }
408}
409
410/// Ensure a remote for a fork PR exists, then return its name.
411///
412/// For non-fork PRs this is equivalent to [`detect_pr_remote`].
413/// For fork PRs, a remote named `pr-{number}-fork` is added if it doesn't
414/// already exist, pointing at the fork's clone URL.
415pub fn setup_fork_remote(repo: &Repository, metadata: &PrMetadata) -> Result<String> {
416    if !metadata.is_fork {
417        // Not a fork - use regular remote
418        return detect_pr_remote(repo);
419    }
420
421    // Fork PR - need to add fork remote
422    let _fork_owner = metadata
423        .fork_owner
424        .as_ref()
425        .ok_or(PrError::MissingForkOwner)?;
426
427    let fork_url = metadata
428        .fork_url
429        .as_ref()
430        .ok_or(PrError::MissingForkOwner)?;
431
432    // Check if fork remote already exists
433    let fork_remote_name = format!("pr-{}-fork", metadata.number);
434
435    if repo.find_remote(&fork_remote_name).is_ok() {
436        debug!("Fork remote {} already exists", fork_remote_name);
437        return Ok(fork_remote_name);
438    }
439
440    // Add fork as remote
441    debug!("Adding fork remote: {} -> {}", fork_remote_name, fork_url);
442    repo.remote(&fork_remote_name, fork_url)
443        .map_err(|e| PrError::FetchFailed {
444            remote: fork_remote_name.clone(),
445            message: format!("Failed to add fork remote: {}", e),
446        })?;
447
448    Ok(fork_remote_name)
449}
450
451/// Fetch `branch` from `remote_name`, making it available as
452/// `refs/remotes/{remote_name}/{branch}`.
453///
454/// This is used for both fork and non-fork PRs to fetch the PR's head branch
455/// identified via `gh` CLI metadata. If the ref already exists locally the
456/// fetch is skipped.
457pub fn fetch_branch(repo: &Repository, remote_name: &str, branch: &str) -> Result<()> {
458    // Check if branch already exists locally
459    let branch_ref = format!("refs/remotes/{}/{}", remote_name, branch);
460    if repo.find_reference(&branch_ref).is_ok() {
461        debug!("Branch ref {} already exists", branch_ref);
462        return Ok(());
463    }
464
465    debug!("Fetching branch {} from remote {}", branch, remote_name);
466
467    let refspec = format!(
468        "+refs/heads/{}:refs/remotes/{}/{}",
469        branch, remote_name, branch
470    );
471
472    let remote_url = repo
473        .find_remote(remote_name)
474        .ok()
475        .and_then(|r| r.url().map(str::to_string));
476    let mut fetch_options = FetchOptions::new();
477    fetch_options.remote_callbacks(get_remote_callbacks(remote_url.as_deref())?);
478
479    repo.find_remote(remote_name)?
480        .fetch(
481            &[refspec.as_str()],
482            Some(&mut fetch_options),
483            Some("Fetching PR branch"),
484        )
485        .map_err(|e| PrError::FetchFailed {
486            remote: remote_name.to_string(),
487            message: e.message().to_string(),
488        })?;
489
490    debug!("Successfully fetched branch {}", branch);
491    Ok(())
492}
493
/// Format a PR worktree name from `format`.
///
/// Every `{number}` placeholder is replaced with `pr_number`. Text without the
/// placeholder is returned unchanged.
pub fn format_pr_name(format: &str, pr_number: u32) -> String {
    // Splitting on the placeholder and re-joining with the number substitutes
    // every (non-overlapping) occurrence, exactly like `str::replace`.
    let number = pr_number.to_string();
    format
        .split("{number}")
        .collect::<Vec<&str>>()
        .join(number.as_str())
}
500
501/// Prepare everything needed to create a worktree for PR `pr_number`.
502///
503/// Orchestrates the complete PR workflow:
504/// 1. Checks that `gh` CLI is available
505/// 2. Fetches PR metadata via `gh`
506/// 3. Sets up a fork remote if the PR is cross-repository
507/// 4. Fetches the PR's head branch
508/// 5. Formats the worktree name using `pr_format`
509///
510/// Returns `(worktree_name, remote_ref, base_branch)` ready for `add_worktree`.
511pub fn prepare_pr_worktree(
512    repo: &Repository,
513    pr_number: u32,
514    pr_format: &str,
515) -> Result<(String, String, String)> {
516    debug!("Preparing PR worktree for PR #{}", pr_number);
517
518    // Fetch PR metadata from gh CLI
519    let metadata = fetch_pr_metadata(pr_number)?;
520    debug!(
521        "Fetched metadata: title='{}', author='{}', is_fork={}",
522        metadata.title, metadata.author, metadata.is_fork
523    );
524
525    // Setup remote and fetch branch
526    // For fork PRs: setup fork remote and fetch from it
527    // For non-fork PRs: use existing remote (origin/upstream)
528    let remote_name = if metadata.is_fork {
529        setup_fork_remote(repo, &metadata)?
530    } else {
531        detect_pr_remote(repo)?
532    };
533
534    // Fetch the actual branch from gh CLI metadata (works for both fork and non-fork)
535    fetch_branch(repo, &remote_name, &metadata.head_ref)?;
536
537    // Format worktree name using metadata
538    let worktree_name = format_pr_name_with_metadata(pr_format, &metadata);
539    debug!("Worktree name: {}", worktree_name);
540
541    // Build remote ref using the actual branch from metadata
542    let remote_ref = format!("{}/{}", remote_name, metadata.head_ref);
543    debug!("Remote ref: {}", remote_ref);
544
545    Ok((worktree_name, remote_ref, metadata.base_ref))
546}
547
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_hash_number() {
        let pr = parse_pr_reference("#123").unwrap().unwrap();
        assert_eq!(pr.number, 123);
        assert_eq!(pr.remote, None);
    }

    #[test]
    fn test_parse_pr_hash_number() {
        let pr = parse_pr_reference("pr#456").unwrap().unwrap();
        assert_eq!(pr.number, 456);
        assert_eq!(pr.remote, None);
    }

    #[test]
    fn test_parse_pr_dash_number() {
        let pr = parse_pr_reference("pr-789").unwrap().unwrap();
        assert_eq!(pr.number, 789);
        assert_eq!(pr.remote, None);
    }

    #[test]
    fn test_parse_github_url() {
        let pr = parse_pr_reference("https://github.com/owner/repo/pull/999")
            .unwrap()
            .unwrap();
        assert_eq!(pr.number, 999);
        assert_eq!(pr.remote, None);
    }

    #[test]
    fn test_parse_github_url_with_subpage() {
        // URLs for a PR's sub-pages (files, commits, ...) still identify the PR.
        let pr = parse_pr_reference("https://github.com/owner/repo/pull/42/files")
            .unwrap()
            .unwrap();
        assert_eq!(pr.number, 42);
    }

    #[test]
    fn test_parse_github_url_without_number() {
        assert!(parse_pr_reference("https://github.com/owner/repo/pull/").is_err());
    }

    #[test]
    fn test_parse_remote_ref() {
        let pr = parse_pr_reference("origin/pull/111/head").unwrap().unwrap();
        assert_eq!(pr.number, 111);
        assert_eq!(pr.remote, None);
    }

    #[test]
    fn test_parse_remote_ref_invalid_number() {
        assert!(parse_pr_reference("origin/pull/abc/head").is_err());
    }

    #[test]
    fn test_parse_regular_branch_name() {
        let result = parse_pr_reference("my-feature-branch").unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_parse_invalid_number() {
        let result = parse_pr_reference("#abc");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_number_overflow() {
        // One past u32::MAX must be rejected, not wrapped.
        assert!(parse_pr_reference("#4294967296").is_err());
    }

    #[test]
    fn test_is_pr_reference_true() {
        assert!(is_pr_reference("#123"));
        assert!(is_pr_reference("pr#456"));
        assert!(is_pr_reference("pr-789"));
        assert!(is_pr_reference("https://github.com/owner/repo/pull/999"));
        assert!(is_pr_reference("origin/pull/5/head"));
    }

    #[test]
    fn test_is_pr_reference_false() {
        assert!(!is_pr_reference("my-branch"));
        assert!(!is_pr_reference("feature"));
        // Malformed references are routed as "not a PR".
        assert!(!is_pr_reference("#abc"));
    }

    #[test]
    fn test_format_pr_name() {
        assert_eq!(format_pr_name("pr-{number}", 123), "pr-123");
        assert_eq!(format_pr_name("review-{number}", 456), "review-456");
        assert_eq!(format_pr_name("{number}-test", 789), "789-test");
    }

    #[test]
    fn test_format_pr_name_placeholder_edge_cases() {
        assert_eq!(format_pr_name("{number}-{number}", 7), "7-7");
        assert_eq!(format_pr_name("static", 1), "static");
        assert_eq!(format_pr_name("", 5), "");
    }

    #[test]
    fn test_sanitize_branch_name() {
        assert_eq!(sanitize_for_branch_name("Fix Bug #123"), "fix-bug-123");
        assert_eq!(
            sanitize_for_branch_name("Add Feature (v2)"),
            "add-feature-v2"
        );
        assert_eq!(sanitize_for_branch_name("john-smith"), "john-smith");
        assert_eq!(
            sanitize_for_branch_name("Fix: Authentication Issue"),
            "fix-authentication-issue"
        );
        assert_eq!(sanitize_for_branch_name("Test@#$%"), "test");
    }

    #[test]
    fn test_sanitize_branch_name_edge_cases() {
        assert_eq!(sanitize_for_branch_name(""), "");
        assert_eq!(sanitize_for_branch_name("---"), "");
        assert_eq!(sanitize_for_branch_name("__init__"), "init");
        assert_eq!(sanitize_for_branch_name("a  b"), "a-b");
        assert_eq!(sanitize_for_branch_name("feature/Fix_Auth"), "feature-fix_auth");
    }

    #[test]
    fn test_format_with_metadata() {
        let metadata = PrMetadata {
            number: 123,
            title: "Fix Authentication Bug".to_string(),
            author: "john-smith".to_string(),
            head_ref: "feature/fix-auth".to_string(),
            base_ref: "main".to_string(),
            is_fork: false,
            fork_owner: None,
            fork_url: None,
        };

        assert_eq!(
            format_pr_name_with_metadata("pr-{number}", &metadata),
            "pr-123"
        );
        assert_eq!(
            format_pr_name_with_metadata("{number}-{title}", &metadata),
            "123-fix-authentication-bug"
        );
        assert_eq!(
            format_pr_name_with_metadata("{author}/pr-{number}", &metadata),
            "john-smith/pr-123"
        );
        assert_eq!(
            format_pr_name_with_metadata("{branch}-{number}", &metadata),
            "feature-fix-auth-123"
        );
    }

    // Integration tests requiring gh CLI (marked with #[ignore])
    #[test]
    #[ignore]
    fn test_gh_cli_available() {
        check_gh_available().expect("gh CLI should be installed");
    }

    #[test]
    #[ignore]
    fn test_fetch_real_pr_metadata() {
        // Requires gh CLI, authentication, and a current working directory
        // inside a GitHub repository that has a PR #1. Adjust the number for
        // the repository you run this in.
        let metadata = fetch_pr_metadata(1).expect("Failed to fetch PR metadata");
        assert_eq!(metadata.number, 1);
        assert!(!metadata.title.is_empty());
        assert!(!metadata.author.is_empty());
    }
}