1use anyhow::{Context, Result};
9use octocrab::Octocrab;
10use tracing::{debug, instrument};
11
12use super::{ReferenceKind, parse_github_reference};
13use crate::ai::types::{PrDetails, PrFile, PrReviewComment, ReviewEvent};
14use crate::error::{AptuError, ResourceType};
15use crate::triage::render_pr_review_comment_body;
16
/// Summary of a pull request that was just created.
///
/// Built by `create_pull_request` from the GitHub response and serializable
/// so callers can emit it as structured output.
#[derive(Debug, serde::Serialize)]
pub struct PrCreateResult {
    /// Number GitHub assigned to the pull request.
    pub pr_number: u64,
    /// Browser (HTML) URL of the pull request.
    pub url: String,
    /// Name of the head branch (the branch carrying the changes).
    pub branch: String,
    /// Name of the base branch (the branch the changes target).
    pub base: String,
    /// Pull request title.
    pub title: String,
    /// Whether the pull request is a draft.
    pub draft: bool,
    /// Number of files changed by the pull request.
    pub files_changed: u32,
    /// Number of added lines.
    pub additions: u64,
    /// Number of deleted lines.
    pub deletions: u64,
}
39
40pub fn parse_pr_reference(
60 reference: &str,
61 repo_context: Option<&str>,
62) -> Result<(String, String, u64)> {
63 parse_github_reference(ReferenceKind::Pull, reference, repo_context)
64}
65
66#[instrument(skip(client), fields(owner = %owner, repo = %repo, number = number))]
85pub async fn fetch_pr_details(
86 client: &Octocrab,
87 owner: &str,
88 repo: &str,
89 number: u64,
90 review_config: &crate::config::ReviewConfig,
91) -> Result<PrDetails> {
92 debug!("Fetching PR details");
93
94 let pr = match client.pulls(owner, repo).get(number).await {
96 Ok(pr) => pr,
97 Err(e) => {
98 if let octocrab::Error::GitHub { source, .. } = &e
100 && source.status_code == 404
101 {
102 if (client.issues(owner, repo).get(number).await).is_ok() {
104 return Err(AptuError::TypeMismatch {
105 number,
106 expected: ResourceType::PullRequest,
107 actual: ResourceType::Issue,
108 }
109 .into());
110 }
111 }
113 return Err(e)
114 .with_context(|| format!("Failed to fetch PR #{number} from {owner}/{repo}"));
115 }
116 };
117
118 let files = client
120 .pulls(owner, repo)
121 .list_files(number)
122 .await
123 .with_context(|| format!("Failed to fetch files for PR #{number}"))?;
124
125 let pr_files: Vec<PrFile> = files
127 .items
128 .into_iter()
129 .map(|f| PrFile {
130 filename: f.filename,
131 status: format!("{:?}", f.status),
132 additions: f.additions,
133 deletions: f.deletions,
134 patch: f.patch,
135 full_content: None,
136 })
137 .collect();
138
139 let file_contents = fetch_file_contents(
141 client,
142 owner,
143 repo,
144 &pr_files,
145 &pr.head.sha,
146 review_config.max_full_content_files,
147 review_config.max_chars_per_file,
148 )
149 .await;
150
151 debug_assert_eq!(
153 pr_files.len(),
154 file_contents.len(),
155 "fetch_file_contents must return one entry per file"
156 );
157 let pr_files: Vec<PrFile> = pr_files
158 .into_iter()
159 .zip(file_contents)
160 .map(|(mut file, content)| {
161 file.full_content = content;
162 file
163 })
164 .collect();
165
166 let labels: Vec<String> = pr.labels.iter().map(|l| l.name.clone()).collect();
167
168 let details = PrDetails {
169 owner: owner.to_string(),
170 repo: repo.to_string(),
171 number,
172 title: pr.title.clone(),
173 body: pr.body.clone().unwrap_or_default(),
174 base_branch: pr.base.ref_field,
175 head_branch: pr.head.ref_field,
176 head_sha: pr.head.sha,
177 files: pr_files,
178 url: pr.html_url.to_string(),
179 labels,
180 review_comments: Vec::new(),
181 };
182
183 debug!(
184 file_count = details.files.len(),
185 "PR details fetched successfully"
186 );
187
188 Ok(details)
189}
190
191#[instrument(skip(client, files), fields(owner = %owner, repo = %repo, max_files = max_files))]
213async fn fetch_file_contents(
214 client: &Octocrab,
215 owner: &str,
216 repo: &str,
217 files: &[PrFile],
218 head_sha: &str,
219 max_files: usize,
220 max_chars_per_file: usize,
221) -> Vec<Option<String>> {
222 let mut results = Vec::with_capacity(files.len());
223 let mut fetched_count = 0usize;
224
225 for file in files {
226 if should_skip_file(&file.filename, &file.status, file.patch.as_ref()) {
227 results.push(None);
228 continue;
229 }
230
231 if fetched_count >= max_files {
233 debug!(
234 file = %file.filename,
235 fetched_count = fetched_count,
236 max_files = max_files,
237 "Fetched file count exceeds max_files cap"
238 );
239 results.push(None);
240 continue;
241 }
242
243 match client
245 .repos(owner, repo)
246 .get_content()
247 .path(&file.filename)
248 .r#ref(head_sha)
249 .send()
250 .await
251 {
252 Ok(content) => {
253 if let Some(item) = content.items.first() {
255 if let Some(decoded) = item.decoded_content() {
256 let truncated = if decoded.len() > max_chars_per_file {
257 decoded.chars().take(max_chars_per_file).collect::<String>()
258 } else {
259 decoded
260 };
261 debug!(
262 file = %file.filename,
263 content_len = truncated.len(),
264 "File content fetched and truncated"
265 );
266 results.push(Some(truncated));
267 fetched_count += 1;
268 } else {
269 tracing::warn!(
270 file = %file.filename,
271 "Failed to decode file content; skipping"
272 );
273 results.push(None);
274 }
275 } else {
276 tracing::warn!(
277 file = %file.filename,
278 "File content response was empty; skipping"
279 );
280 results.push(None);
281 }
282 }
283 Err(e) => {
284 tracing::warn!(
285 file = %file.filename,
286 err = %e,
287 "Failed to fetch file content; skipping"
288 );
289 results.push(None);
290 }
291 }
292 }
293
294 results
295}
296
297#[allow(clippy::too_many_arguments)]
321#[instrument(skip(client, comments), fields(owner = %owner, repo = %repo, number = number, event = %event))]
322pub async fn post_pr_review(
323 client: &Octocrab,
324 owner: &str,
325 repo: &str,
326 number: u64,
327 body: &str,
328 event: ReviewEvent,
329 comments: &[PrReviewComment],
330 commit_id: &str,
331) -> Result<u64> {
332 debug!("Posting PR review");
333
334 let route = format!("/repos/{owner}/{repo}/pulls/{number}/reviews");
335
336 let inline_comments: Vec<serde_json::Value> = comments
338 .iter()
339 .filter_map(|c| {
341 c.line.map(|line| {
342 serde_json::json!({
343 "path": c.file,
344 "line": line,
345 "side": "RIGHT",
349 "body": render_pr_review_comment_body(c),
350 })
351 })
352 })
353 .collect();
354
355 let mut payload = serde_json::json!({
356 "body": body,
357 "event": event.to_string(),
358 "comments": inline_comments,
359 });
360
361 if !commit_id.is_empty() {
363 payload["commit_id"] = serde_json::Value::String(commit_id.to_string());
364 }
365
366 #[derive(serde::Deserialize)]
367 struct ReviewResponse {
368 id: u64,
369 }
370
371 let response: ReviewResponse = client.post(route, Some(&payload)).await.with_context(|| {
372 format!(
373 "Failed to post review to PR #{number} in {owner}/{repo}. \
374 Check that you have write access to the repository."
375 )
376 })?;
377
378 debug!(review_id = response.id, "PR review posted successfully");
379
380 Ok(response.id)
381}
382
383#[instrument(skip(client), fields(owner = %owner, repo = %repo, comment_id = comment_id))]
390pub async fn delete_pr_review_comment(
391 client: &Octocrab,
392 owner: &str,
393 repo: &str,
394 comment_id: u64,
395) -> Result<()> {
396 debug!("Deleting PR review comment");
397
398 let route = format!("/repos/{owner}/{repo}/pulls/comments/{comment_id}");
399
400 let empty_body = serde_json::json!({});
402 let result: std::result::Result<serde_json::Value, _> =
403 client.delete(&route, Some(&empty_body)).await;
404
405 match result {
406 Ok(_) => {
407 debug!("PR review comment deleted successfully");
408 Ok(())
409 }
410 Err(e)
411 if let octocrab::Error::GitHub { source, .. } = &e
412 && source.status_code.as_u16() == 404 =>
413 {
414 debug!("PR review comment already deleted (404); treating as success");
415 Ok(())
416 }
417 Err(e) => {
418 Err(e).with_context(|| format!("Failed to delete PR review comment #{comment_id}"))
419 }
420 }
421}
422
/// Derives labels for a pull request from its title and changed file paths.
///
/// The conventional-commit type at the start of `title` (the text before the
/// first `:`, with any `(scope)` and a trailing breaking-change `!` removed)
/// is mapped as follows:
///
/// * `feat`, `perf` -> `enhancement`
/// * `fix` -> `bug`
/// * `docs` -> `documentation`
/// * `refactor` -> `refactor`
///
/// Each path additionally contributes a scope label:
///
/// * `crates/aptu-cli/...` -> `cli`
/// * `docs/...` -> `documentation`
///
/// # Returns
///
/// The distinct labels in ascending lexicographic order, so the result is
/// deterministic across runs. Empty when nothing matches.
#[must_use]
pub fn labels_from_pr_metadata(title: &str, file_paths: &[String]) -> Vec<String> {
    // A BTreeSet deduplicates and iterates in sorted order; a HashSet would
    // make the order of the returned Vec vary from run to run.
    let mut labels: std::collections::BTreeSet<&'static str> = std::collections::BTreeSet::new();

    // "feat(cli)!: add x" -> "feat(cli)!" -> "feat"; "feat!: add x" -> "feat".
    let prefix = title
        .split(':')
        .next()
        .unwrap_or("")
        .split('(')
        .next()
        .unwrap_or("")
        .trim()
        .trim_end_matches('!');

    let type_label = match prefix {
        "feat" | "perf" => Some("enhancement"),
        "fix" => Some("bug"),
        "docs" => Some("documentation"),
        "refactor" => Some("refactor"),
        _ => None,
    };
    labels.extend(type_label);

    labels.extend(file_paths.iter().filter_map(|path| {
        if path.starts_with("crates/aptu-cli/") {
            Some("cli")
        } else if path.starts_with("docs/") {
            Some("documentation")
        } else {
            None
        }
    }));

    labels.into_iter().map(String::from).collect()
}
479
480#[instrument(skip(client), fields(owner = %owner, repo = %repo, head = %head_branch, base = %base_branch))]
500#[allow(clippy::too_many_arguments)]
501pub async fn create_pull_request(
502 client: &Octocrab,
503 owner: &str,
504 repo: &str,
505 title: &str,
506 head_branch: &str,
507 base_branch: &str,
508 body: Option<&str>,
509 draft: bool,
510) -> anyhow::Result<PrCreateResult> {
511 debug!("Creating pull request");
512
513 let pr = client
514 .pulls(owner, repo)
515 .create(title, head_branch, base_branch)
516 .body(body.unwrap_or_default())
517 .draft(draft)
518 .send()
519 .await
520 .with_context(|| {
521 format!("Failed to create PR in {owner}/{repo} ({head_branch} -> {base_branch})")
522 })?;
523
524 let result = PrCreateResult {
525 pr_number: pr.number,
526 url: pr.html_url.to_string(),
527 branch: pr.head.ref_field,
528 base: pr.base.ref_field,
529 title: pr.title.clone(),
530 draft: pr.draft.unwrap_or(false),
531 files_changed: u32::try_from(pr.changed_files).unwrap_or(u32::MAX),
532 additions: pr.additions,
533 deletions: pr.deletions,
534 };
535
536 debug!(
537 pr_number = result.pr_number,
538 "Pull request created successfully"
539 );
540
541 Ok(result)
542}
543
544fn should_skip_file(filename: &str, status: &str, patch: Option<&String>) -> bool {
548 if status.to_lowercase().contains("removed") {
549 debug!(file = %filename, "Skipping removed file");
550 return true;
551 }
552 if patch.is_none_or(String::is_empty) {
553 debug!(file = %filename, "Skipping file with empty patch");
554 return true;
555 }
556 false
557}
558
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ai::types::CommentSeverity;

    /// Test-only mirror of the decode-and-truncate step: decodes base64 text
    /// and keeps at most `max_chars` characters.
    ///
    /// Returns `None` for invalid base64 or non-UTF-8 payloads.
    fn decode_content(encoded: &str, max_chars: usize) -> Option<String> {
        use base64::Engine;
        let engine = base64::engine::general_purpose::STANDARD;
        let decoded_bytes = engine.decode(encoded).ok()?;
        let decoded_str = String::from_utf8(decoded_bytes).ok()?;

        if decoded_str.len() <= max_chars {
            Some(decoded_str)
        } else {
            Some(decoded_str.chars().take(max_chars).collect::<String>())
        }
    }

    /// Test-only mirror of the inline-comment payload built by
    /// `post_pr_review`: comments without a line number are dropped.
    fn build_inline_comments(comments: &[PrReviewComment]) -> Vec<serde_json::Value> {
        comments
            .iter()
            .filter_map(|c| {
                c.line.map(|line| {
                    serde_json::json!({
                        "path": c.file,
                        "line": line,
                        "side": "RIGHT",
                        "body": render_pr_review_comment_body(c),
                    })
                })
            })
            .collect()
    }

    /// Sorts labels so assertions do not depend on the order in which
    /// `labels_from_pr_metadata` happens to return them.
    fn sorted(mut labels: Vec<String>) -> Vec<String> {
        labels.sort();
        labels
    }

    #[test]
    fn test_pr_create_result_fields() {
        let result = PrCreateResult {
            pr_number: 42,
            url: "https://github.com/owner/repo/pull/42".to_string(),
            branch: "feat/my-feature".to_string(),
            base: "main".to_string(),
            title: "feat: add feature".to_string(),
            draft: false,
            files_changed: 3,
            additions: 100,
            deletions: 10,
        };

        assert_eq!(result.pr_number, 42);
        assert_eq!(result.url, "https://github.com/owner/repo/pull/42");
        assert_eq!(result.branch, "feat/my-feature");
        assert_eq!(result.base, "main");
        assert_eq!(result.title, "feat: add feature");
        assert!(!result.draft);
        assert_eq!(result.files_changed, 3);
        assert_eq!(result.additions, 100);
        assert_eq!(result.deletions, 10);
    }

    #[test]
    fn test_post_pr_review_payload_with_comments() {
        let comments = vec![PrReviewComment {
            file: "src/main.rs".to_string(),
            line: Some(42),
            comment: "Consider using a match here.".to_string(),
            severity: CommentSeverity::Suggestion,
            suggested_code: None,
        }];

        let inline = build_inline_comments(&comments);

        assert_eq!(inline.len(), 1);
        assert_eq!(inline[0]["path"], "src/main.rs");
        assert_eq!(inline[0]["line"], 42);
        assert_eq!(inline[0]["side"], "RIGHT");
        assert_eq!(inline[0]["body"], "Consider using a match here.");
    }

    #[test]
    fn test_post_pr_review_skips_none_line_comments() {
        let comments = vec![
            PrReviewComment {
                file: "src/lib.rs".to_string(),
                line: None,
                comment: "General file comment.".to_string(),
                severity: CommentSeverity::Info,
                suggested_code: None,
            },
            PrReviewComment {
                file: "src/lib.rs".to_string(),
                line: Some(10),
                comment: "Inline comment.".to_string(),
                severity: CommentSeverity::Warning,
                suggested_code: None,
            },
        ];

        let inline = build_inline_comments(&comments);

        assert_eq!(inline.len(), 1);
        assert_eq!(inline[0]["line"], 10);
    }

    #[test]
    fn test_post_pr_review_empty_comments() {
        let comments: Vec<PrReviewComment> = vec![];

        let inline = build_inline_comments(&comments);

        assert!(inline.is_empty());
        let serialized = serde_json::to_string(&inline).unwrap();
        assert_eq!(serialized, "[]");
    }

    #[test]
    fn test_parse_pr_reference_delegates_to_shared() {
        let (owner, repo, number) =
            parse_pr_reference("https://github.com/block/goose/pull/123", None).unwrap();
        assert_eq!(owner, "block");
        assert_eq!(repo, "goose");
        assert_eq!(number, 123);
    }

    #[test]
    fn test_title_prefix_to_label_mapping() {
        let cases = vec![
            (
                "feat: add new feature",
                vec!["enhancement"],
                "feat should map to enhancement",
            ),
            ("fix: resolve bug", vec!["bug"], "fix should map to bug"),
            (
                "docs: update readme",
                vec!["documentation"],
                "docs should map to documentation",
            ),
            (
                "refactor: improve code",
                vec!["refactor"],
                "refactor should map to refactor",
            ),
            (
                "perf: optimize",
                vec!["enhancement"],
                "perf should map to enhancement",
            ),
            (
                "chore: update deps",
                vec![],
                "chore should produce no labels",
            ),
        ];

        for (title, expected_labels, msg) in cases {
            // Compare the whole label set so unexpected extra labels fail too.
            let labels = sorted(labels_from_pr_metadata(title, &[]));
            assert_eq!(labels, expected_labels, "{msg}");
        }
    }

    #[test]
    fn test_file_path_to_scope_mapping() {
        // Expected labels are listed in sorted order.
        let cases = vec![
            (
                "feat: cli",
                vec!["crates/aptu-cli/src/main.rs"],
                vec!["cli", "enhancement"],
                "cli path should map to cli scope",
            ),
            (
                "feat: docs",
                vec!["docs/GITHUB_ACTION.md"],
                vec!["documentation", "enhancement"],
                "docs path should map to documentation scope",
            ),
            (
                "feat: workflow",
                vec![".github/workflows/test.yml"],
                vec!["enhancement"],
                "workflow path should be ignored",
            ),
        ];

        for (title, paths, expected_labels, msg) in cases {
            let paths: Vec<String> = paths
                .iter()
                .map(std::string::ToString::to_string)
                .collect();
            let labels = sorted(labels_from_pr_metadata(title, &paths));
            assert_eq!(labels, expected_labels, "{msg}");
        }
    }

    #[test]
    fn test_combined_title_and_paths() {
        let labels = sorted(labels_from_pr_metadata(
            "feat: multi",
            &[
                "crates/aptu-cli/src/main.rs".to_string(),
                "docs/README.md".to_string(),
            ],
        ));
        assert_eq!(
            labels,
            vec!["cli", "documentation", "enhancement"],
            "should combine the feat prefix label with both path scopes"
        );
    }

    #[test]
    fn test_no_match_returns_empty() {
        let cases = vec![
            (
                "Random title",
                Vec::<String>::new(),
                "unrecognized prefix should return empty",
            ),
            (
                "chore: update",
                Vec::<String>::new(),
                "ignored prefix should return empty",
            ),
        ];

        for (title, paths, msg) in cases {
            let labels = labels_from_pr_metadata(title, &paths);
            assert!(labels.is_empty(), "{msg}: got {labels:?}");
        }
    }

    #[test]
    fn test_scoped_prefix_extracts_type() {
        let labels = labels_from_pr_metadata("feat(cli): add new feature", &[]);
        assert_eq!(
            labels,
            vec!["enhancement"],
            "scoped prefix should extract type from feat(cli)"
        );
    }

    #[test]
    fn test_duplicate_labels_deduplicated() {
        let labels = labels_from_pr_metadata("docs: update", &["docs/README.md".to_string()]);
        assert_eq!(
            labels,
            vec!["documentation"],
            "should have exactly one label when title and path both map to documentation"
        );
    }

    // Covers the skip rules only; the max_files cap lives in
    // `fetch_file_contents` and needs a client to exercise.
    #[test]
    fn test_should_skip_file_removed_and_patchless() {
        let removed_file = PrFile {
            filename: "removed.rs".to_string(),
            status: "removed".to_string(),
            additions: 0,
            deletions: 5,
            patch: None,
            full_content: None,
        };
        let modified_file = PrFile {
            filename: "file_0.rs".to_string(),
            status: "modified".to_string(),
            additions: 1,
            deletions: 0,
            patch: Some("+ new code".to_string()),
            full_content: None,
        };
        let no_patch_file = PrFile {
            filename: "file_1.rs".to_string(),
            status: "modified".to_string(),
            additions: 1,
            deletions: 0,
            patch: None,
            full_content: None,
        };

        assert!(
            should_skip_file(
                &removed_file.filename,
                &removed_file.status,
                removed_file.patch.as_ref()
            ),
            "removed files should be skipped"
        );

        assert!(
            !should_skip_file(
                &modified_file.filename,
                &modified_file.status,
                modified_file.patch.as_ref()
            ),
            "modified files with patch should not be skipped"
        );

        assert!(
            should_skip_file(
                &no_patch_file.filename,
                &no_patch_file.status,
                no_patch_file.patch.as_ref()
            ),
            "files without patch should be skipped"
        );
    }

    #[test]
    fn test_decode_content_valid_base64() {
        use base64::Engine;
        let engine = base64::engine::general_purpose::STANDARD;
        let original = "Hello, World!";
        let encoded = engine.encode(original);

        let result = decode_content(&encoded, 1000);

        assert_eq!(
            result,
            Some(original.to_string()),
            "valid base64 should decode successfully"
        );
    }

    #[test]
    fn test_decode_content_invalid_base64() {
        let invalid_base64 = "!!!invalid!!!";

        let result = decode_content(invalid_base64, 1000);

        assert_eq!(result, None, "invalid base64 should return None");
    }

    #[test]
    fn test_decode_content_truncates_at_max_chars() {
        use base64::Engine;
        let engine = base64::engine::general_purpose::STANDARD;
        // 50 multi-byte characters, so byte length and char count differ.
        let original = "こんにちは".repeat(10);
        let encoded = engine.encode(&original);
        let max_chars = 10;

        let result = decode_content(&encoded, max_chars);

        assert!(result.is_some(), "decoding should succeed");
        let decoded = result.unwrap();
        assert_eq!(
            decoded.chars().count(),
            max_chars,
            "output should be truncated to max_chars on character boundary"
        );
        // Pin the exact content: the first ten characters, none of them split.
        assert_eq!(
            decoded,
            "こんにちは".repeat(2),
            "output should be the first max_chars characters of the input"
        );
    }
}