1use anyhow::{Context, Result};
9use octocrab::Octocrab;
10use tracing::{debug, instrument};
11
12use super::{ReferenceKind, parse_github_reference};
13use crate::ai::types::{PrDetails, PrFile, PrReviewComment, ReviewEvent};
14use crate::error::{AptuError, ResourceType};
15use crate::triage::render_pr_review_comment_body;
16
/// Summary of a newly created pull request, returned by
/// [`create_pull_request`] and serialized for CLI/JSON output.
#[derive(Debug, serde::Serialize)]
pub struct PrCreateResult {
    /// PR number assigned by GitHub.
    pub pr_number: u64,
    /// HTML URL of the PR (empty string if GitHub omitted it).
    pub url: String,
    /// Head (source) branch name.
    pub branch: String,
    /// Base (target) branch name.
    pub base: String,
    /// PR title as echoed back by the API.
    pub title: String,
    /// Whether the PR was created as a draft.
    pub draft: bool,
    /// Number of files changed (saturated to `u32::MAX` on overflow).
    pub files_changed: u32,
    /// Total added lines across the diff.
    pub additions: u64,
    /// Total deleted lines across the diff.
    pub deletions: u64,
}
39
/// Parse a pull-request reference (URL, `owner/repo#N`, or a bare number
/// resolved against `repo_context`) into `(owner, repo, number)`.
///
/// Thin wrapper that delegates to [`parse_github_reference`] with
/// [`ReferenceKind::Pull`]; see that function for the accepted formats.
///
/// # Errors
/// Propagates any parse error from [`parse_github_reference`].
pub fn parse_pr_reference(
    reference: &str,
    repo_context: Option<&str>,
) -> Result<(String, String, u64)> {
    parse_github_reference(ReferenceKind::Pull, reference, repo_context)
}
65
/// Fetch a pull request's metadata, its changed files, and (for a bounded
/// subset of files) their full contents at the PR's head commit.
///
/// The number of files fetched in full and the per-file character cap are
/// controlled by `review_config` (`max_full_content_files`,
/// `max_chars_per_file`).
///
/// # Errors
/// - [`AptuError::TypeMismatch`] if `number` resolves to an issue rather
///   than a pull request (detected via a secondary probe on 404).
/// - Any other API failure, with context attached.
#[instrument(skip(client), fields(owner = %owner, repo = %repo, number = number))]
pub async fn fetch_pr_details(
    client: &Octocrab,
    owner: &str,
    repo: &str,
    number: u64,
    review_config: &crate::config::ReviewConfig,
) -> Result<PrDetails> {
    debug!("Fetching PR details");

    let pr = match client.pulls(owner, repo).get(number).await {
        Ok(pr) => pr,
        Err(e) => {
            // On 404, probe the issues endpoint: GitHub numbers issues and
            // PRs from the same sequence, so a "missing" PR may actually be
            // an issue. Surface that as a type mismatch instead of a
            // generic fetch failure.
            if let octocrab::Error::GitHub { source, .. } = &e
                && source.status_code == 404
            {
                if (client.issues(owner, repo).get(number).await).is_ok() {
                    return Err(AptuError::TypeMismatch {
                        number,
                        expected: ResourceType::PullRequest,
                        actual: ResourceType::Issue,
                    }
                    .into());
                }
            }
            return Err(e)
                .with_context(|| format!("Failed to fetch PR #{number} from {owner}/{repo}"));
        }
    };

    let files = client
        .pulls(owner, repo)
        .list_files(number)
        .await
        .with_context(|| format!("Failed to fetch files for PR #{number}"))?;

    // Map the API file model into our PrFile; full_content is filled below.
    let pr_files: Vec<PrFile> = files
        .items
        .into_iter()
        .map(|f| PrFile {
            filename: f.filename,
            status: format!("{:?}", f.status),
            additions: f.additions,
            deletions: f.deletions,
            patch: f.patch,
            full_content: None,
        })
        .collect();

    let file_contents = fetch_file_contents(
        client,
        owner,
        repo,
        &pr_files,
        &pr.head.sha,
        review_config.max_full_content_files,
        review_config.max_chars_per_file,
    )
    .await;

    // fetch_file_contents guarantees one entry per input file (None for
    // skipped/failed files), which makes the zip below a positional match.
    debug_assert_eq!(
        pr_files.len(),
        file_contents.len(),
        "fetch_file_contents must return one entry per file"
    );
    let pr_files: Vec<PrFile> = pr_files
        .into_iter()
        .zip(file_contents)
        .map(|(mut file, content)| {
            file.full_content = content;
            file
        })
        .collect();

    // `pr.labels` is optional in the API model; flat_map flattens
    // Option<Vec<Label>> into the label names.
    let labels: Vec<String> = pr
        .labels
        .iter()
        .flat_map(|labels_vec| labels_vec.iter().map(|l| l.name.clone()))
        .collect();

    let details = PrDetails {
        owner: owner.to_string(),
        repo: repo.to_string(),
        number,
        title: pr.title.unwrap_or_default(),
        body: pr.body.unwrap_or_default(),
        base_branch: pr.base.ref_field,
        head_branch: pr.head.ref_field,
        head_sha: pr.head.sha,
        files: pr_files,
        url: pr.html_url.map_or_else(String::new, |u| u.to_string()),
        labels,
    };

    debug!(
        file_count = details.files.len(),
        "PR details fetched successfully"
    );

    Ok(details)
}
193
194#[instrument(skip(client, files), fields(owner = %owner, repo = %repo, max_files = max_files))]
216async fn fetch_file_contents(
217 client: &Octocrab,
218 owner: &str,
219 repo: &str,
220 files: &[PrFile],
221 head_sha: &str,
222 max_files: usize,
223 max_chars_per_file: usize,
224) -> Vec<Option<String>> {
225 let mut results = Vec::with_capacity(files.len());
226 let mut fetched_count = 0usize;
227
228 for file in files {
229 if should_skip_file(&file.filename, &file.status, file.patch.as_ref()) {
230 results.push(None);
231 continue;
232 }
233
234 if fetched_count >= max_files {
236 debug!(
237 file = %file.filename,
238 fetched_count = fetched_count,
239 max_files = max_files,
240 "Fetched file count exceeds max_files cap"
241 );
242 results.push(None);
243 continue;
244 }
245
246 match client
248 .repos(owner, repo)
249 .get_content()
250 .path(&file.filename)
251 .r#ref(head_sha)
252 .send()
253 .await
254 {
255 Ok(content) => {
256 if let Some(item) = content.items.first() {
258 if let Some(decoded) = item.decoded_content() {
259 let truncated = if decoded.len() > max_chars_per_file {
260 decoded.chars().take(max_chars_per_file).collect::<String>()
261 } else {
262 decoded
263 };
264 debug!(
265 file = %file.filename,
266 content_len = truncated.len(),
267 "File content fetched and truncated"
268 );
269 results.push(Some(truncated));
270 fetched_count += 1;
271 } else {
272 tracing::warn!(
273 file = %file.filename,
274 "Failed to decode file content; skipping"
275 );
276 results.push(None);
277 }
278 } else {
279 tracing::warn!(
280 file = %file.filename,
281 "File content response was empty; skipping"
282 );
283 results.push(None);
284 }
285 }
286 Err(e) => {
287 tracing::warn!(
288 file = %file.filename,
289 err = %e,
290 "Failed to fetch file content; skipping"
291 );
292 results.push(None);
293 }
294 }
295 }
296
297 results
298}
299
300#[allow(clippy::too_many_arguments)]
324#[instrument(skip(client, comments), fields(owner = %owner, repo = %repo, number = number, event = %event))]
325pub async fn post_pr_review(
326 client: &Octocrab,
327 owner: &str,
328 repo: &str,
329 number: u64,
330 body: &str,
331 event: ReviewEvent,
332 comments: &[PrReviewComment],
333 commit_id: &str,
334) -> Result<u64> {
335 debug!("Posting PR review");
336
337 let route = format!("/repos/{owner}/{repo}/pulls/{number}/reviews");
338
339 let inline_comments: Vec<serde_json::Value> = comments
341 .iter()
342 .filter_map(|c| {
344 c.line.map(|line| {
345 serde_json::json!({
346 "path": c.file,
347 "line": line,
348 "side": "RIGHT",
352 "body": render_pr_review_comment_body(c),
353 })
354 })
355 })
356 .collect();
357
358 let mut payload = serde_json::json!({
359 "body": body,
360 "event": event.to_string(),
361 "comments": inline_comments,
362 });
363
364 if !commit_id.is_empty() {
366 payload["commit_id"] = serde_json::Value::String(commit_id.to_string());
367 }
368
369 #[derive(serde::Deserialize)]
370 struct ReviewResponse {
371 id: u64,
372 }
373
374 let response: ReviewResponse = client.post(route, Some(&payload)).await.with_context(|| {
375 format!(
376 "Failed to post review to PR #{number} in {owner}/{repo}. \
377 Check that you have write access to the repository."
378 )
379 })?;
380
381 debug!(review_id = response.id, "PR review posted successfully");
382
383 Ok(response.id)
384}
385
/// Infer GitHub labels for a PR from its title and changed file paths.
///
/// The conventional-commit type prefix of `title` (e.g. `feat`, `fix`,
/// `docs`, with or without a scope like `feat(cli)`) maps to a type label,
/// and each path prefix maps to an area label. Duplicates are collapsed;
/// the order of the returned labels is unspecified.
#[must_use]
pub fn labels_from_pr_metadata(title: &str, file_paths: &[String]) -> Vec<String> {
    let mut collected = std::collections::HashSet::new();

    // "feat(cli): msg" -> "feat": text before the first ':', then before
    // the first '(' (drops any scope), trimmed.
    let before_colon = title.split(':').next().unwrap_or("");
    let commit_type = before_colon.split('(').next().unwrap_or("").trim();

    let type_label = match commit_type {
        "feat" | "perf" => Some("enhancement"),
        "fix" => Some("bug"),
        "docs" => Some("documentation"),
        "refactor" => Some("refactor"),
        _ => None,
    };
    collected.extend(type_label.map(str::to_string));

    // Area labels from well-known path prefixes; unknown paths contribute
    // nothing.
    collected.extend(file_paths.iter().filter_map(|path| {
        if path.starts_with("crates/aptu-cli/") {
            Some("cli".to_string())
        } else if path.starts_with("crates/aptu-ffi/") || path.starts_with("AptuApp/") {
            Some("ios".to_string())
        } else if path.starts_with("docs/") {
            Some("documentation".to_string())
        } else {
            None
        }
    }));

    collected.into_iter().collect()
}
444
/// Create a pull request from `head_branch` into `base_branch` and return
/// a [`PrCreateResult`] summary of what GitHub created.
///
/// `body` defaults to an empty string when `None`; `draft` controls
/// whether the PR is opened as a draft.
///
/// # Errors
/// Returns an error (with owner/repo/branch context) if the GitHub API
/// call fails.
#[instrument(skip(client), fields(owner = %owner, repo = %repo, head = %head_branch, base = %base_branch))]
#[allow(clippy::too_many_arguments)]
pub async fn create_pull_request(
    client: &Octocrab,
    owner: &str,
    repo: &str,
    title: &str,
    head_branch: &str,
    base_branch: &str,
    body: Option<&str>,
    draft: bool,
) -> anyhow::Result<PrCreateResult> {
    debug!("Creating pull request");

    let pr = client
        .pulls(owner, repo)
        .create(title, head_branch, base_branch)
        .body(body.unwrap_or_default())
        .draft(draft)
        .send()
        .await
        .with_context(|| {
            format!("Failed to create PR in {owner}/{repo} ({head_branch} -> {base_branch})")
        })?;

    let result = PrCreateResult {
        pr_number: pr.number,
        // html_url is optional in the API model; fall back to empty string.
        url: pr.html_url.map_or_else(String::new, |u| u.to_string()),
        branch: pr.head.ref_field,
        base: pr.base.ref_field,
        title: pr.title.unwrap_or_default(),
        draft: pr.draft.unwrap_or(false),
        // changed_files is a wider integer in the API model; saturate to
        // u32::MAX rather than truncate.
        files_changed: u32::try_from(pr.changed_files.unwrap_or_default()).unwrap_or(u32::MAX),
        additions: pr.additions.unwrap_or_default(),
        deletions: pr.deletions.unwrap_or_default(),
    };

    debug!(
        pr_number = result.pr_number,
        "Pull request created successfully"
    );

    Ok(result)
}
508
509fn should_skip_file(filename: &str, status: &str, patch: Option<&String>) -> bool {
513 if status.to_lowercase().contains("removed") {
514 debug!(file = %filename, "Skipping removed file");
515 return true;
516 }
517 if patch.is_none_or(String::is_empty) {
518 debug!(file = %filename, "Skipping file with empty patch");
519 return true;
520 }
521 false
522}
523
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ai::types::CommentSeverity;

    /// Test-local mirror of the decode-and-truncate logic used when
    /// fetching file contents, so truncation behavior can be exercised
    /// without a GitHub client.
    fn decode_content(encoded: &str, max_chars: usize) -> Option<String> {
        use base64::Engine;
        let engine = base64::engine::general_purpose::STANDARD;
        let decoded_bytes = engine.decode(encoded).ok()?;
        let decoded_str = String::from_utf8(decoded_bytes).ok()?;

        if decoded_str.len() <= max_chars {
            Some(decoded_str)
        } else {
            Some(decoded_str.chars().take(max_chars).collect::<String>())
        }
    }

    #[test]
    fn test_pr_create_result_fields() {
        let result = PrCreateResult {
            pr_number: 42,
            url: "https://github.com/owner/repo/pull/42".to_string(),
            branch: "feat/my-feature".to_string(),
            base: "main".to_string(),
            title: "feat: add feature".to_string(),
            draft: false,
            files_changed: 3,
            additions: 100,
            deletions: 10,
        };

        assert_eq!(result.pr_number, 42);
        assert_eq!(result.url, "https://github.com/owner/repo/pull/42");
        assert_eq!(result.branch, "feat/my-feature");
        assert_eq!(result.base, "main");
        assert_eq!(result.title, "feat: add feature");
        assert!(!result.draft);
        assert_eq!(result.files_changed, 3);
        assert_eq!(result.additions, 100);
        assert_eq!(result.deletions, 10);
    }

    /// Test-local mirror of the inline-comment payload construction in
    /// `post_pr_review` (comments without a line number are dropped).
    fn build_inline_comments(comments: &[PrReviewComment]) -> Vec<serde_json::Value> {
        comments
            .iter()
            .filter_map(|c| {
                c.line.map(|line| {
                    serde_json::json!({
                        "path": c.file,
                        "line": line,
                        "side": "RIGHT",
                        "body": render_pr_review_comment_body(c),
                    })
                })
            })
            .collect()
    }

    #[test]
    fn test_post_pr_review_payload_with_comments() {
        let comments = vec![PrReviewComment {
            file: "src/main.rs".to_string(),
            line: Some(42),
            comment: "Consider using a match here.".to_string(),
            severity: CommentSeverity::Suggestion,
            suggested_code: None,
        }];

        let inline = build_inline_comments(&comments);

        assert_eq!(inline.len(), 1);
        assert_eq!(inline[0]["path"], "src/main.rs");
        assert_eq!(inline[0]["line"], 42);
        assert_eq!(inline[0]["side"], "RIGHT");
        assert_eq!(inline[0]["body"], "Consider using a match here.");
    }

    #[test]
    fn test_post_pr_review_skips_none_line_comments() {
        let comments = vec![
            PrReviewComment {
                file: "src/lib.rs".to_string(),
                line: None,
                comment: "General file comment.".to_string(),
                severity: CommentSeverity::Info,
                suggested_code: None,
            },
            PrReviewComment {
                file: "src/lib.rs".to_string(),
                line: Some(10),
                comment: "Inline comment.".to_string(),
                severity: CommentSeverity::Warning,
                suggested_code: None,
            },
        ];

        let inline = build_inline_comments(&comments);

        assert_eq!(inline.len(), 1);
        assert_eq!(inline[0]["line"], 10);
    }

    #[test]
    fn test_post_pr_review_empty_comments() {
        let comments: Vec<PrReviewComment> = vec![];

        let inline = build_inline_comments(&comments);

        assert!(inline.is_empty());
        let serialized = serde_json::to_string(&inline).unwrap();
        assert_eq!(serialized, "[]");
    }

    #[test]
    fn test_parse_pr_reference_delegates_to_shared() {
        let (owner, repo, number) =
            parse_pr_reference("https://github.com/block/goose/pull/123", None).unwrap();
        assert_eq!(owner, "block");
        assert_eq!(repo, "goose");
        assert_eq!(number, 123);
    }

    #[test]
    fn test_title_prefix_to_label_mapping() {
        let cases = vec![
            (
                "feat: add new feature",
                vec!["enhancement"],
                "feat should map to enhancement",
            ),
            ("fix: resolve bug", vec!["bug"], "fix should map to bug"),
            (
                "docs: update readme",
                vec!["documentation"],
                "docs should map to documentation",
            ),
            (
                "refactor: improve code",
                vec!["refactor"],
                "refactor should map to refactor",
            ),
            (
                "perf: optimize",
                vec!["enhancement"],
                "perf should map to enhancement",
            ),
            (
                "chore: update deps",
                vec![],
                "chore should produce no labels",
            ),
        ];

        for (title, expected_labels, msg) in cases {
            let labels = labels_from_pr_metadata(title, &[]);
            for expected in &expected_labels {
                assert!(
                    labels.contains(&expected.to_string()),
                    "{msg}: expected '{expected}' in {labels:?}",
                );
            }
            if expected_labels.is_empty() {
                assert!(labels.is_empty(), "{msg}: expected empty, got {labels:?}",);
            }
        }
    }

    #[test]
    fn test_file_path_to_scope_mapping() {
        let cases = vec![
            (
                "feat: cli",
                vec!["crates/aptu-cli/src/main.rs"],
                vec!["enhancement", "cli"],
                "cli path should map to cli scope",
            ),
            (
                "feat: ios",
                vec!["crates/aptu-ffi/src/lib.rs"],
                vec!["enhancement", "ios"],
                "ffi path should map to ios scope",
            ),
            (
                "feat: ios",
                vec!["AptuApp/ContentView.swift"],
                vec!["enhancement", "ios"],
                "app path should map to ios scope",
            ),
            (
                "feat: docs",
                vec!["docs/GITHUB_ACTION.md"],
                vec!["enhancement", "documentation"],
                "docs path should map to documentation scope",
            ),
            (
                "feat: workflow",
                vec![".github/workflows/test.yml"],
                vec!["enhancement"],
                "workflow path should be ignored",
            ),
        ];

        for (title, paths, expected_labels, msg) in cases {
            let labels = labels_from_pr_metadata(
                title,
                &paths
                    .iter()
                    .map(std::string::ToString::to_string)
                    .collect::<Vec<_>>(),
            );
            for expected in expected_labels {
                assert!(
                    labels.contains(&expected.to_string()),
                    "{msg}: expected '{expected}' in {labels:?}",
                );
            }
        }
    }

    #[test]
    fn test_combined_title_and_paths() {
        let labels = labels_from_pr_metadata(
            "feat: multi",
            &[
                "crates/aptu-cli/src/main.rs".to_string(),
                "docs/README.md".to_string(),
            ],
        );
        assert!(
            labels.contains(&"enhancement".to_string()),
            "should include enhancement from feat prefix"
        );
        assert!(
            labels.contains(&"cli".to_string()),
            "should include cli from path"
        );
        assert!(
            labels.contains(&"documentation".to_string()),
            "should include documentation from path"
        );
    }

    #[test]
    fn test_no_match_returns_empty() {
        let cases = vec![
            (
                "Random title",
                vec![],
                "unrecognized prefix should return empty",
            ),
            (
                "chore: update",
                vec![],
                "ignored prefix should return empty",
            ),
        ];

        for (title, paths, msg) in cases {
            let labels = labels_from_pr_metadata(title, &paths);
            assert!(labels.is_empty(), "{msg}: got {labels:?}");
        }
    }

    #[test]
    fn test_scoped_prefix_extracts_type() {
        let labels = labels_from_pr_metadata("feat(cli): add new feature", &[]);
        assert!(
            labels.contains(&"enhancement".to_string()),
            "scoped prefix should extract type from feat(cli)"
        );
    }

    #[test]
    fn test_duplicate_labels_deduplicated() {
        let labels = labels_from_pr_metadata("docs: update", &["docs/README.md".to_string()]);
        assert_eq!(
            labels.len(),
            1,
            "should have exactly one label when title and path both map to documentation"
        );
        assert!(
            labels.contains(&"documentation".to_string()),
            "should contain documentation label"
        );
    }

    #[test]
    fn test_should_skip_file_respects_fetched_count_cap() {
        let removed_file = PrFile {
            filename: "removed.rs".to_string(),
            status: "removed".to_string(),
            additions: 0,
            deletions: 5,
            patch: None,
            full_content: None,
        };
        let modified_file = PrFile {
            filename: "file_0.rs".to_string(),
            status: "modified".to_string(),
            additions: 1,
            deletions: 0,
            patch: Some("+ new code".to_string()),
            full_content: None,
        };
        let no_patch_file = PrFile {
            filename: "file_1.rs".to_string(),
            status: "modified".to_string(),
            additions: 1,
            deletions: 0,
            patch: None,
            full_content: None,
        };

        assert!(
            should_skip_file(
                &removed_file.filename,
                &removed_file.status,
                removed_file.patch.as_ref()
            ),
            "removed files should be skipped"
        );

        assert!(
            !should_skip_file(
                &modified_file.filename,
                &modified_file.status,
                modified_file.patch.as_ref()
            ),
            "modified files with patch should not be skipped"
        );

        assert!(
            should_skip_file(
                &no_patch_file.filename,
                &no_patch_file.status,
                no_patch_file.patch.as_ref()
            ),
            "files without patch should be skipped"
        );
    }

    #[test]
    fn test_decode_content_valid_base64() {
        use base64::Engine;
        let engine = base64::engine::general_purpose::STANDARD;
        let original = "Hello, World!";
        let encoded = engine.encode(original);

        let result = decode_content(&encoded, 1000);

        assert_eq!(
            result,
            Some(original.to_string()),
            "valid base64 should decode successfully"
        );
    }

    #[test]
    fn test_decode_content_invalid_base64() {
        let invalid_base64 = "!!!invalid!!!";

        let result = decode_content(invalid_base64, 1000);

        assert_eq!(result, None, "invalid base64 should return None");
    }

    #[test]
    fn test_decode_content_truncates_at_max_chars() {
        use base64::Engine;
        let engine = base64::engine::general_purpose::STANDARD;
        // Multi-byte input: 50 chars / 150 bytes, so the truncation path runs.
        let original = "こんにちは".repeat(10);
        let encoded = engine.encode(&original);
        let max_chars = 10;

        let result = decode_content(&encoded, max_chars);

        assert!(result.is_some(), "decoding should succeed");
        let decoded = result.unwrap();
        assert_eq!(
            decoded.chars().count(),
            max_chars,
            "output should be truncated to max_chars on character boundary"
        );
        assert!(
            decoded.is_char_boundary(decoded.len()),
            "output should be valid UTF-8 (truncated on char boundary)"
        );
    }
}