1use git2::{FetchOptions, Repository};
82use log::debug;
83
84use crate::{
85 error::{PrError, Result},
86 get_remote_callbacks,
87};
88
/// A pull request reference extracted from user input.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PullRequest {
    /// The pull request number.
    pub number: u32,
    /// Remote the reference belongs to, if known. The parsers visible in this
    /// file always leave this as `None`.
    pub remote: Option<String>,
}
97
/// Pull request details as reported by `gh pr view --json ...`.
///
/// Derives `PartialEq`/`Eq` (like [`PullRequest`]) so values can be compared
/// directly, e.g. in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PrMetadata {
    /// The pull request number (`number`).
    pub number: u32,
    /// The pull request title (`title`).
    pub title: String,
    /// Login of the pull request author (`author.login`).
    pub author: String,
    /// Name of the branch the pull request was opened from (`headRefName`).
    pub head_ref: String,
    /// Name of the branch the pull request targets (`baseRefName`).
    pub base_ref: String,
    /// Whether the head branch lives in a different repository
    /// (`isCrossRepository`).
    pub is_fork: bool,
    /// Login of the head repository's owner (`headRepository.owner.login`);
    /// only populated when `is_fork` is true.
    pub fork_owner: Option<String>,
    /// URL of the head repository (`headRepository.url`); only populated when
    /// `is_fork` is true, and may still be `None` if the field is absent.
    pub fork_url: Option<String>,
}
118
119pub fn parse_pr_reference(input: &str) -> Result<Option<PullRequest>> {
131 if let Some(num_str) = input.strip_prefix('#') {
133 return parse_number(num_str, input).map(|num| {
134 Some(PullRequest {
135 number: num,
136 remote: None,
137 })
138 });
139 }
140
141 if let Some(num_str) = input.strip_prefix("pr#") {
143 return parse_number(num_str, input).map(|num| {
144 Some(PullRequest {
145 number: num,
146 remote: None,
147 })
148 });
149 }
150
151 if let Some(num_str) = input.strip_prefix("pr-") {
153 return parse_number(num_str, input).map(|num| {
154 Some(PullRequest {
155 number: num,
156 remote: None,
157 })
158 });
159 }
160
161 if input.contains("github.com") && input.contains("/pull/") {
163 return parse_github_url(input);
164 }
165
166 if input.contains("/pull/") && input.ends_with("/head") {
168 return parse_remote_ref(input);
169 }
170
171 Ok(None)
173}
174
175fn parse_number(num_str: &str, original_input: &str) -> Result<u32> {
177 num_str.parse::<u32>().map_err(|_| {
178 PrError::InvalidReference {
179 input: original_input.to_string(),
180 }
181 .into()
182 })
183}
184
185fn parse_github_url(url: &str) -> Result<Option<PullRequest>> {
187 let parts: Vec<&str> = url.split('/').collect();
189
190 for (i, &part) in parts.iter().enumerate() {
192 if part == "pull" && i + 1 < parts.len() {
193 let num_str = parts[i + 1];
194 let number = parse_number(num_str, url)?;
195 return Ok(Some(PullRequest {
196 number,
197 remote: None,
198 }));
199 }
200 }
201
202 Err(PrError::InvalidReference {
203 input: url.to_string(),
204 }
205 .into())
206}
207
208fn parse_remote_ref(ref_str: &str) -> Result<Option<PullRequest>> {
210 let parts: Vec<&str> = ref_str.split('/').collect();
212
213 if parts.len() >= 4 && parts[parts.len() - 3] == "pull" && parts[parts.len() - 1] == "head" {
214 let num_str = parts[parts.len() - 2];
215 let number = parse_number(num_str, ref_str)?;
216 return Ok(Some(PullRequest {
217 number,
218 remote: None,
219 }));
220 }
221
222 Err(PrError::InvalidReference {
223 input: ref_str.to_string(),
224 }
225 .into())
226}
227
228pub fn check_gh_available() -> Result<()> {
232 std::process::Command::new("gh")
233 .arg("--version")
234 .output()
235 .map_err(|_| PrError::GhNotInstalled)?;
236 Ok(())
237}
238
239pub fn fetch_pr_metadata(pr_number: u32) -> Result<PrMetadata> {
244 check_gh_available()?;
246
247 let output = std::process::Command::new("gh")
249 .args([
250 "pr",
251 "view",
252 &pr_number.to_string(),
253 "--json",
254 "number,title,author,headRefName,baseRefName,isCrossRepository,headRepository",
255 ])
256 .output()
257 .map_err(|e| PrError::GhFetchFailed {
258 message: e.to_string(),
259 })?;
260
261 if !output.status.success() {
262 let stderr = String::from_utf8_lossy(&output.stderr);
263 return Err(PrError::GhFetchFailed {
264 message: stderr.to_string(),
265 }
266 .into());
267 }
268
269 let json_str = String::from_utf8_lossy(&output.stdout);
271 let json: serde_json::Value =
272 serde_json::from_str(&json_str).map_err(|e| PrError::GhJsonParseFailed {
273 message: e.to_string(),
274 })?;
275
276 let number = json["number"]
278 .as_u64()
279 .ok_or_else(|| PrError::GhJsonParseFailed {
280 message: "Missing 'number' field".to_string(),
281 })? as u32;
282
283 let title = json["title"]
284 .as_str()
285 .ok_or_else(|| PrError::GhJsonParseFailed {
286 message: "Missing 'title' field".to_string(),
287 })?
288 .to_string();
289
290 let author = json["author"]["login"]
291 .as_str()
292 .ok_or_else(|| PrError::GhJsonParseFailed {
293 message: "Missing 'author.login' field".to_string(),
294 })?
295 .to_string();
296
297 let head_ref = json["headRefName"]
298 .as_str()
299 .ok_or_else(|| PrError::GhJsonParseFailed {
300 message: "Missing 'headRefName' field".to_string(),
301 })?
302 .to_string();
303
304 let base_ref = json["baseRefName"]
305 .as_str()
306 .ok_or_else(|| PrError::GhJsonParseFailed {
307 message: "Missing 'baseRefName' field".to_string(),
308 })?
309 .to_string();
310
311 let is_fork = json["isCrossRepository"].as_bool().unwrap_or(false);
312
313 let (fork_owner, fork_url) = if is_fork {
314 let owner = json["headRepository"]["owner"]["login"]
315 .as_str()
316 .ok_or(PrError::MissingForkOwner)?
317 .to_string();
318 let url = json["headRepository"]["url"]
319 .as_str()
320 .map(|s| s.to_string());
321 (Some(owner), url)
322 } else {
323 (None, None)
324 };
325
326 Ok(PrMetadata {
327 number,
328 title,
329 author,
330 head_ref,
331 base_ref,
332 is_fork,
333 fork_owner,
334 fork_url,
335 })
336}
337
/// Turns arbitrary text into a lowercase slug usable inside a branch name.
///
/// ASCII letters are lowercased; ASCII digits and `_` are kept. Every other
/// character, including `-` itself, becomes `-`. Runs of `-` collapse to one,
/// and leading or trailing `-` and `_` are stripped. The result may be empty.
fn sanitize_for_branch_name(s: &str) -> String {
    let mut slug = String::with_capacity(s.len());

    for ch in s.chars() {
        let mapped = if ch.is_ascii_alphanumeric() || ch == '_' {
            ch.to_ascii_lowercase()
        } else {
            '-'
        };

        // Skip a dash that would directly follow another dash.
        if mapped == '-' && slug.ends_with('-') {
            continue;
        }
        slug.push(mapped);
    }

    slug.trim_matches(|c: char| matches!(c, '-' | '_'))
        .to_string()
}
367
368pub fn format_pr_name_with_metadata(format: &str, metadata: &PrMetadata) -> String {
373 format
374 .replace("{number}", &metadata.number.to_string())
375 .replace("{title}", &sanitize_for_branch_name(&metadata.title))
376 .replace("{author}", &sanitize_for_branch_name(&metadata.author))
377 .replace("{branch}", &sanitize_for_branch_name(&metadata.head_ref))
378}
379
380pub fn is_pr_reference(input: &str) -> bool {
384 parse_pr_reference(input).ok().flatten().is_some()
385}
386
/// Ranks a remote name for sorting; lower values sort first.
///
/// `upstream` ranks 0, `origin` ranks 1, and every other name ranks 2.
pub fn remote_priority(remote: &str) -> usize {
    const PREFERRED: [&str; 2] = ["upstream", "origin"];

    PREFERRED
        .iter()
        .position(|&name| name == remote)
        .unwrap_or(PREFERRED.len())
}
400
401pub fn preferred_remote_order(repo: &Repository) -> Vec<String> {
404 let Ok(remotes) = repo.remotes() else {
405 return vec![];
406 };
407 let mut all: Vec<String> = remotes
408 .iter()
409 .flatten()
410 .flatten()
411 .map(str::to_string)
412 .collect();
413 all.sort_by_key(|r| remote_priority(r));
414 all
415}
416
417pub fn detect_pr_remote(repo: &Repository) -> Result<String> {
422 preferred_remote_order(repo)
423 .into_iter()
424 .next()
425 .ok_or_else(|| PrError::NoRemoteConfigured.into())
426}
427
428pub fn setup_fork_remote(repo: &Repository, metadata: &PrMetadata) -> Result<String> {
434 if !metadata.is_fork {
435 return detect_pr_remote(repo);
437 }
438
439 let _fork_owner = metadata
441 .fork_owner
442 .as_ref()
443 .ok_or(PrError::MissingForkOwner)?;
444
445 let fork_url = metadata
446 .fork_url
447 .as_ref()
448 .ok_or(PrError::MissingForkOwner)?;
449
450 let fork_remote_name = format!("pr-{}-fork", metadata.number);
452
453 if repo.find_remote(&fork_remote_name).is_ok() {
454 debug!("Fork remote {} already exists", fork_remote_name);
455 return Ok(fork_remote_name);
456 }
457
458 debug!("Adding fork remote: {} -> {}", fork_remote_name, fork_url);
460 repo.remote(&fork_remote_name, fork_url)
461 .map_err(|e| PrError::FetchFailed {
462 remote: fork_remote_name.clone(),
463 message: format!("Failed to add fork remote: {}", e),
464 })?;
465
466 Ok(fork_remote_name)
467}
468
469pub fn fetch_branch(repo: &Repository, remote_name: &str, branch: &str) -> Result<()> {
476 let branch_ref = format!("refs/remotes/{}/{}", remote_name, branch);
478 if repo.find_reference(&branch_ref).is_ok() {
479 debug!("Branch ref {} already exists", branch_ref);
480 return Ok(());
481 }
482
483 debug!("Fetching branch {} from remote {}", branch, remote_name);
484
485 let refspec = format!(
486 "+refs/heads/{}:refs/remotes/{}/{}",
487 branch, remote_name, branch
488 );
489
490 let remote_url = repo
491 .find_remote(remote_name)
492 .ok()
493 .and_then(|r| r.url().ok().map(str::to_string));
494 let auth = get_remote_callbacks(repo, remote_url.as_deref())?;
495 let mut fetch_options = FetchOptions::new();
496 fetch_options.remote_callbacks(auth.callbacks());
497
498 repo.find_remote(remote_name)?
499 .fetch(
500 &[refspec.as_str()],
501 Some(&mut fetch_options),
502 Some("Fetching PR branch"),
503 )
504 .map_err(|e| PrError::FetchFailed {
505 remote: remote_name.to_string(),
506 message: e.message().to_string(),
507 })?;
508
509 debug!("Successfully fetched branch {}", branch);
510 Ok(())
511}
512
/// Replaces every `{number}` placeholder in `format` with `pr_number`.
///
/// No other placeholder is expanded.
pub fn format_pr_name(format: &str, pr_number: u32) -> String {
    let number = pr_number.to_string();
    let pieces: Vec<&str> = format.split("{number}").collect();
    pieces.join(number.as_str())
}
519
520pub fn prepare_pr_worktree(
531 repo: &Repository,
532 pr_number: u32,
533 pr_format: &str,
534) -> Result<(String, String, String)> {
535 debug!("Preparing PR worktree for PR #{}", pr_number);
536
537 let metadata = fetch_pr_metadata(pr_number)?;
539 debug!(
540 "Fetched metadata: title='{}', author='{}', is_fork={}",
541 metadata.title, metadata.author, metadata.is_fork
542 );
543
544 let remote_name = if metadata.is_fork {
548 setup_fork_remote(repo, &metadata)?
549 } else {
550 detect_pr_remote(repo)?
551 };
552
553 fetch_branch(repo, &remote_name, &metadata.head_ref)?;
555
556 let worktree_name = format_pr_name_with_metadata(pr_format, &metadata);
558 debug!("Worktree name: {}", worktree_name);
559
560 let remote_ref = format!("{}/{}", remote_name, metadata.head_ref);
562 debug!("Remote ref: {}", remote_ref);
563
564 Ok((worktree_name, remote_ref, metadata.base_ref))
565}
566
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input`, panicking unless it is accepted as a PR reference.
    fn expect_pr(input: &str) -> PullRequest {
        match parse_pr_reference(input) {
            Ok(Some(pr)) => pr,
            Ok(None) => panic!("`{}` was not recognised as a PR reference", input),
            Err(_) => panic!("`{}` failed to parse", input),
        }
    }

    /// The value the parsers are expected to produce for PR `number`.
    fn bare_pr(number: u32) -> PullRequest {
        PullRequest {
            number,
            remote: None,
        }
    }

    #[test]
    fn test_parse_hash_number() {
        assert_eq!(expect_pr("#123"), bare_pr(123));
    }

    #[test]
    fn test_parse_pr_hash_number() {
        assert_eq!(expect_pr("pr#456"), bare_pr(456));
    }

    #[test]
    fn test_parse_pr_dash_number() {
        assert_eq!(expect_pr("pr-789"), bare_pr(789));
    }

    #[test]
    fn test_parse_github_url() {
        assert_eq!(
            expect_pr("https://github.com/owner/repo/pull/999"),
            bare_pr(999)
        );
    }

    #[test]
    fn test_parse_remote_ref() {
        assert_eq!(expect_pr("origin/pull/111/head"), bare_pr(111));
    }

    #[test]
    fn test_parse_regular_branch_name() {
        let outcome = parse_pr_reference("my-feature-branch");
        assert!(matches!(outcome, Ok(None)));
    }

    #[test]
    fn test_parse_invalid_number() {
        assert!(parse_pr_reference("#abc").is_err());
    }

    #[test]
    fn test_is_pr_reference_true() {
        let references = [
            "#123",
            "pr#456",
            "pr-789",
            "https://github.com/owner/repo/pull/999",
        ];
        for input in references.iter() {
            assert!(is_pr_reference(input), "`{}` should be a PR reference", input);
        }
    }

    #[test]
    fn test_is_pr_reference_false() {
        let branches = ["my-branch", "feature"];
        for input in branches.iter() {
            assert!(
                !is_pr_reference(input),
                "`{}` should not be a PR reference",
                input
            );
        }
    }

    #[test]
    fn test_format_pr_name() {
        let cases = [
            ("pr-{number}", 123, "pr-123"),
            ("review-{number}", 456, "review-456"),
            ("{number}-test", 789, "789-test"),
        ];
        for &(format, number, expected) in cases.iter() {
            assert_eq!(format_pr_name(format, number), expected);
        }
    }

    #[test]
    fn test_sanitize_branch_name() {
        let cases = [
            ("Fix Bug #123", "fix-bug-123"),
            ("Add Feature (v2)", "add-feature-v2"),
            ("john-smith", "john-smith"),
            ("Fix: Authentication Issue", "fix-authentication-issue"),
            ("Test@#$%", "test"),
        ];
        for &(raw, expected) in cases.iter() {
            assert_eq!(sanitize_for_branch_name(raw), expected);
        }
    }

    #[test]
    fn test_format_with_metadata() {
        let metadata = PrMetadata {
            number: 123,
            title: "Fix Authentication Bug".to_string(),
            author: "john-smith".to_string(),
            head_ref: "feature/fix-auth".to_string(),
            base_ref: "main".to_string(),
            is_fork: false,
            fork_owner: None,
            fork_url: None,
        };

        let cases = [
            ("pr-{number}", "pr-123"),
            ("{number}-{title}", "123-fix-authentication-bug"),
            ("{author}/pr-{number}", "john-smith/pr-123"),
            ("{branch}-{number}", "feature-fix-auth-123"),
        ];
        for &(format, expected) in cases.iter() {
            assert_eq!(format_pr_name_with_metadata(format, &metadata), expected);
        }
    }

    // The two tests below invoke the real `gh` binary, so they are ignored by
    // default.

    #[test]
    #[ignore]
    fn test_gh_cli_available() {
        check_gh_available().expect("gh CLI should be installed");
    }

    #[test]
    #[ignore]
    fn test_fetch_real_pr_metadata() {
        let metadata = fetch_pr_metadata(1).expect("Failed to fetch PR metadata");
        assert_eq!(metadata.number, 1);
        assert!(!metadata.title.is_empty());
        assert!(!metadata.author.is_empty());
    }
}