1use anyhow::{Context, Result};
9use octocrab::Octocrab;
10use tracing::{debug, instrument};
11
12use super::{ReferenceKind, parse_github_reference};
13use crate::ai::types::{PrDetails, PrFile, PrReviewComment, ReviewEvent};
14use crate::error::{AptuError, ResourceType};
15use crate::triage::render_pr_review_comment_body;
16
/// Serializable summary of a pull request created by [`create_pull_request`].
#[derive(Debug, serde::Serialize)]
pub struct PrCreateResult {
    /// PR number assigned by GitHub.
    pub pr_number: u64,
    /// HTML URL of the created PR; empty when GitHub returned no URL.
    pub url: String,
    /// Head (source) branch name as echoed back by GitHub.
    pub branch: String,
    /// Base (target) branch name as echoed back by GitHub.
    pub base: String,
    /// PR title as echoed back by GitHub.
    pub title: String,
    /// Whether the PR was created as a draft (defaults to `false` when unset).
    pub draft: bool,
    /// Number of files changed; saturates at `u32::MAX` on conversion.
    pub files_changed: u32,
    /// Total lines added across all files.
    pub additions: u64,
    /// Total lines deleted across all files.
    pub deletions: u64,
}
39
/// Parse a pull-request reference string into `(owner, repo, number)`.
///
/// Thin wrapper that delegates to [`parse_github_reference`] with
/// [`ReferenceKind::Pull`]. `repo_context` is forwarded as-is; the accepted
/// reference formats (full URLs are covered by the tests below) are defined
/// by `parse_github_reference`.
///
/// # Errors
/// Propagates any parse failure from [`parse_github_reference`].
pub fn parse_pr_reference(
    reference: &str,
    repo_context: Option<&str>,
) -> Result<(String, String, u64)> {
    parse_github_reference(ReferenceKind::Pull, reference, repo_context)
}
65
/// Fetch a pull request's metadata, changed files, and (within the limits set
/// by `review_config`) the full head-revision contents of those files.
///
/// When the PR lookup 404s, the same number is probed against the issues
/// endpoint; if it resolves there, a typed [`AptuError::TypeMismatch`] is
/// returned so callers can report "that's an issue, not a PR" precisely.
///
/// # Errors
/// Returns an error when the PR or its file list cannot be fetched, or the
/// `TypeMismatch` error described above.
#[instrument(skip(client), fields(owner = %owner, repo = %repo, number = number))]
pub async fn fetch_pr_details(
    client: &Octocrab,
    owner: &str,
    repo: &str,
    number: u64,
    review_config: &crate::config::ReviewConfig,
) -> Result<PrDetails> {
    debug!("Fetching PR details");

    let pr = match client.pulls(owner, repo).get(number).await {
        Ok(pr) => pr,
        Err(e) => {
            // A 404 can mean "wrong resource type" rather than "not found":
            // check whether the same number exists as an issue.
            if let octocrab::Error::GitHub { source, .. } = &e
                && source.status_code == 404
            {
                if (client.issues(owner, repo).get(number).await).is_ok() {
                    return Err(AptuError::TypeMismatch {
                        number,
                        expected: ResourceType::PullRequest,
                        actual: ResourceType::Issue,
                    }
                    .into());
                }
            }
            // Genuine failure (or a 404 with no matching issue): surface it
            // with repository context attached.
            return Err(e)
                .with_context(|| format!("Failed to fetch PR #{number} from {owner}/{repo}"));
        }
    };

    let files = client
        .pulls(owner, repo)
        .list_files(number)
        .await
        .with_context(|| format!("Failed to fetch files for PR #{number}"))?;

    // First pass: metadata and patch only; full contents are filled in below.
    let pr_files: Vec<PrFile> = files
        .items
        .into_iter()
        .map(|f| PrFile {
            filename: f.filename,
            status: format!("{:?}", f.status),
            additions: f.additions,
            deletions: f.deletions,
            patch: f.patch,
            full_content: None,
        })
        .collect();

    // Best-effort fetch of file bodies at the PR head commit, capped by the
    // review configuration; failures inside produce `None` entries.
    let file_contents = fetch_file_contents(
        client,
        owner,
        repo,
        &pr_files,
        &pr.head.sha,
        review_config.max_full_content_files,
        review_config.max_chars_per_file,
    )
    .await;

    // fetch_file_contents guarantees positional alignment, so zipping is safe.
    debug_assert_eq!(
        pr_files.len(),
        file_contents.len(),
        "fetch_file_contents must return one entry per file"
    );
    let pr_files: Vec<PrFile> = pr_files
        .into_iter()
        .zip(file_contents)
        .map(|(mut file, content)| {
            file.full_content = content;
            file
        })
        .collect();

    // Labels come back as an optional list; flatten to plain names.
    let labels: Vec<String> = pr
        .labels
        .iter()
        .flat_map(|labels_vec| labels_vec.iter().map(|l| l.name.clone()))
        .collect();

    let details = PrDetails {
        owner: owner.to_string(),
        repo: repo.to_string(),
        number,
        title: pr.title.unwrap_or_default(),
        body: pr.body.unwrap_or_default(),
        base_branch: pr.base.ref_field,
        head_branch: pr.head.ref_field,
        head_sha: pr.head.sha,
        files: pr_files,
        url: pr.html_url.map_or_else(String::new, |u| u.to_string()),
        labels,
    };

    debug!(
        file_count = details.files.len(),
        "PR details fetched successfully"
    );

    Ok(details)
}
193
/// Fetch full file contents at `head_sha` for up to `max_files` of the given
/// PR files, returning exactly one `Option<String>` per input file, in order.
///
/// A file maps to `None` when it is removed, has a missing/empty patch,
/// arrives after the `max_files` cap is reached, fails to fetch, or cannot be
/// decoded. Fetched content is truncated to `max_chars_per_file` characters.
/// All failures are logged and skipped rather than propagated — this is a
/// best-effort enrichment step.
#[instrument(skip(client, files), fields(owner = %owner, repo = %repo, max_files = max_files))]
async fn fetch_file_contents(
    client: &Octocrab,
    owner: &str,
    repo: &str,
    files: &[PrFile],
    head_sha: &str,
    max_files: usize,
    max_chars_per_file: usize,
) -> Vec<Option<String>> {
    let mut results = Vec::with_capacity(files.len());
    // Counts successful fetches only; skipped files do not consume the cap.
    let mut fetched_count = 0usize;

    for file in files {
        // Removed files have no content at the head revision.
        if file.status.to_lowercase().contains("removed") {
            debug!(file = %file.filename, "Skipping removed file");
            results.push(None);
            continue;
        }

        // No patch (or an empty one) means the diff carries nothing to review.
        if file.patch.as_ref().is_none_or(String::is_empty) {
            debug!(file = %file.filename, "Skipping file with empty patch");
            results.push(None);
            continue;
        }

        // Respect the configured cap on full-content fetches.
        if fetched_count >= max_files {
            debug!(
                file = %file.filename,
                fetched_count = fetched_count,
                max_files = max_files,
                "Fetched file count exceeds max_files cap"
            );
            results.push(None);
            continue;
        }

        match client
            .repos(owner, repo)
            .get_content()
            .path(&file.filename)
            .r#ref(head_sha)
            .send()
            .await
        {
            Ok(content) => {
                if let Some(item) = content.items.first() {
                    if let Some(decoded) = item.decoded_content() {
                        // Byte-length pre-check; the truncation itself walks
                        // chars so multi-byte text is cut on a char boundary.
                        let truncated = if decoded.len() > max_chars_per_file {
                            decoded.chars().take(max_chars_per_file).collect::<String>()
                        } else {
                            decoded
                        };
                        debug!(
                            file = %file.filename,
                            content_len = truncated.len(),
                            "File content fetched and truncated"
                        );
                        results.push(Some(truncated));
                        fetched_count += 1;
                    } else {
                        tracing::warn!(
                            file = %file.filename,
                            "Failed to decode file content; skipping"
                        );
                        results.push(None);
                    }
                } else {
                    tracing::warn!(
                        file = %file.filename,
                        "File content response was empty; skipping"
                    );
                    results.push(None);
                }
            }
            Err(e) => {
                tracing::warn!(
                    file = %file.filename,
                    err = %e,
                    "Failed to fetch file content; skipping"
                );
                results.push(None);
            }
        }
    }

    results
}
308
309#[allow(clippy::too_many_arguments)]
333#[instrument(skip(client, comments), fields(owner = %owner, repo = %repo, number = number, event = %event))]
334pub async fn post_pr_review(
335 client: &Octocrab,
336 owner: &str,
337 repo: &str,
338 number: u64,
339 body: &str,
340 event: ReviewEvent,
341 comments: &[PrReviewComment],
342 commit_id: &str,
343) -> Result<u64> {
344 debug!("Posting PR review");
345
346 let route = format!("/repos/{owner}/{repo}/pulls/{number}/reviews");
347
348 let inline_comments: Vec<serde_json::Value> = comments
350 .iter()
351 .filter_map(|c| {
353 c.line.map(|line| {
354 serde_json::json!({
355 "path": c.file,
356 "line": line,
357 "side": "RIGHT",
361 "body": render_pr_review_comment_body(c),
362 })
363 })
364 })
365 .collect();
366
367 let mut payload = serde_json::json!({
368 "body": body,
369 "event": event.to_string(),
370 "comments": inline_comments,
371 });
372
373 if !commit_id.is_empty() {
375 payload["commit_id"] = serde_json::Value::String(commit_id.to_string());
376 }
377
378 #[derive(serde::Deserialize)]
379 struct ReviewResponse {
380 id: u64,
381 }
382
383 let response: ReviewResponse = client.post(route, Some(&payload)).await.with_context(|| {
384 format!(
385 "Failed to post review to PR #{number} in {owner}/{repo}. \
386 Check that you have write access to the repository."
387 )
388 })?;
389
390 debug!(review_id = response.id, "PR review posted successfully");
391
392 Ok(response.id)
393}
394
/// Derive GitHub label names from a PR title's conventional-commit prefix and
/// the set of changed file paths.
///
/// The commit type (`feat`/`perf` → `enhancement`, `fix` → `bug`, `docs` →
/// `documentation`, `refactor` → `refactor`) yields a type label; well-known
/// path prefixes yield scope labels. Unknown types and paths contribute
/// nothing. A trailing breaking-change marker (`feat!: …`) is stripped before
/// matching the type. The result is de-duplicated and sorted, so output order
/// is deterministic.
#[must_use]
pub fn labels_from_pr_metadata(title: &str, file_paths: &[String]) -> Vec<String> {
    // BTreeSet both de-duplicates and yields a stable, sorted iteration order
    // (HashSet iteration order varies between runs).
    let mut labels = std::collections::BTreeSet::new();

    // "feat(cli): x" -> "feat"; "feat!: x" -> "feat": take the text before
    // ':', drop any "(scope)", trim, then strip the '!' breaking-change mark.
    let prefix = title
        .split(':')
        .next()
        .unwrap_or("")
        .split('(')
        .next()
        .unwrap_or("")
        .trim()
        .trim_end_matches('!');

    let type_label = match prefix {
        "feat" | "perf" => Some("enhancement"),
        "fix" => Some("bug"),
        "docs" => Some("documentation"),
        "refactor" => Some("refactor"),
        _ => None,
    };

    if let Some(label) = type_label {
        labels.insert(label.to_string());
    }

    for path in file_paths {
        // Map well-known directory prefixes to scope labels.
        let scope = if path.starts_with("crates/aptu-cli/") {
            Some("cli")
        } else if path.starts_with("crates/aptu-ffi/") || path.starts_with("AptuApp/") {
            Some("ios")
        } else if path.starts_with("docs/") {
            Some("documentation")
        } else if path.starts_with("snap/") {
            Some("distribution")
        } else {
            None
        };

        if let Some(label) = scope {
            labels.insert(label.to_string());
        }
    }

    labels.into_iter().collect()
}
455
456#[instrument(skip(client), fields(owner = %owner, repo = %repo, head = %head_branch, base = %base_branch))]
476pub async fn create_pull_request(
477 client: &Octocrab,
478 owner: &str,
479 repo: &str,
480 title: &str,
481 head_branch: &str,
482 base_branch: &str,
483 body: Option<&str>,
484) -> anyhow::Result<PrCreateResult> {
485 debug!("Creating pull request");
486
487 let pr = client
488 .pulls(owner, repo)
489 .create(title, head_branch, base_branch)
490 .body(body.unwrap_or_default())
491 .draft(false)
492 .send()
493 .await
494 .with_context(|| {
495 format!("Failed to create PR in {owner}/{repo} ({head_branch} -> {base_branch})")
496 })?;
497
498 let result = PrCreateResult {
499 pr_number: pr.number,
500 url: pr.html_url.map_or_else(String::new, |u| u.to_string()),
501 branch: pr.head.ref_field,
502 base: pr.base.ref_field,
503 title: pr.title.unwrap_or_default(),
504 draft: pr.draft.unwrap_or(false),
505 files_changed: u32::try_from(pr.changed_files.unwrap_or_default()).unwrap_or(u32::MAX),
506 additions: pr.additions.unwrap_or_default(),
507 deletions: pr.deletions.unwrap_or_default(),
508 };
509
510 debug!(
511 pr_number = result.pr_number,
512 "Pull request created successfully"
513 );
514
515 Ok(result)
516}
517
/// Returns `true` when a PR file should be skipped for full-content fetching:
/// the file was removed in the PR, or it carries no (or an empty) diff hunk.
#[inline]
#[allow(dead_code)]
fn should_skip_file(status: &str, patch: Option<&String>) -> bool {
    let removed = status.to_lowercase().contains("removed");
    let has_patch = patch.is_some_and(|p| !p.is_empty());
    removed || !has_patch
}
525
526#[allow(dead_code)]
529fn decode_content(encoded: &str, max_chars: usize) -> Option<String> {
530 use base64::Engine;
531 let engine = base64::engine::general_purpose::STANDARD;
532 let decoded_bytes = engine.decode(encoded).ok()?;
533 let decoded_str = String::from_utf8(decoded_bytes).ok()?;
534
535 if decoded_str.len() <= max_chars {
536 Some(decoded_str)
537 } else {
538 Some(decoded_str.chars().take(max_chars).collect::<String>())
539 }
540}
541
542#[cfg(test)]
543mod tests {
544 use super::*;
545 use crate::ai::types::CommentSeverity;
546
547 #[test]
548 fn test_pr_create_result_fields() {
549 let result = PrCreateResult {
551 pr_number: 42,
552 url: "https://github.com/owner/repo/pull/42".to_string(),
553 branch: "feat/my-feature".to_string(),
554 base: "main".to_string(),
555 title: "feat: add feature".to_string(),
556 draft: false,
557 files_changed: 3,
558 additions: 100,
559 deletions: 10,
560 };
561
562 assert_eq!(result.pr_number, 42);
564 assert_eq!(result.url, "https://github.com/owner/repo/pull/42");
565 assert_eq!(result.branch, "feat/my-feature");
566 assert_eq!(result.base, "main");
567 assert_eq!(result.title, "feat: add feature");
568 assert!(!result.draft);
569 assert_eq!(result.files_changed, 3);
570 assert_eq!(result.additions, 100);
571 assert_eq!(result.deletions, 10);
572 }
573
574 fn build_inline_comments(comments: &[PrReviewComment]) -> Vec<serde_json::Value> {
581 comments
582 .iter()
583 .filter_map(|c| {
584 c.line.map(|line| {
585 serde_json::json!({
586 "path": c.file,
587 "line": line,
588 "side": "RIGHT",
589 "body": render_pr_review_comment_body(c),
590 })
591 })
592 })
593 .collect()
594 }
595
596 #[test]
597 fn test_post_pr_review_payload_with_comments() {
598 let comments = vec![PrReviewComment {
600 file: "src/main.rs".to_string(),
601 line: Some(42),
602 comment: "Consider using a match here.".to_string(),
603 severity: CommentSeverity::Suggestion,
604 suggested_code: None,
605 }];
606
607 let inline = build_inline_comments(&comments);
609
610 assert_eq!(inline.len(), 1);
612 assert_eq!(inline[0]["path"], "src/main.rs");
613 assert_eq!(inline[0]["line"], 42);
614 assert_eq!(inline[0]["side"], "RIGHT");
615 assert_eq!(inline[0]["body"], "Consider using a match here.");
616 }
617
618 #[test]
619 fn test_post_pr_review_skips_none_line_comments() {
620 let comments = vec![
622 PrReviewComment {
623 file: "src/lib.rs".to_string(),
624 line: None,
625 comment: "General file comment.".to_string(),
626 severity: CommentSeverity::Info,
627 suggested_code: None,
628 },
629 PrReviewComment {
630 file: "src/lib.rs".to_string(),
631 line: Some(10),
632 comment: "Inline comment.".to_string(),
633 severity: CommentSeverity::Warning,
634 suggested_code: None,
635 },
636 ];
637
638 let inline = build_inline_comments(&comments);
640
641 assert_eq!(inline.len(), 1);
643 assert_eq!(inline[0]["line"], 10);
644 }
645
646 #[test]
647 fn test_post_pr_review_empty_comments() {
648 let comments: Vec<PrReviewComment> = vec![];
650
651 let inline = build_inline_comments(&comments);
653
654 assert!(inline.is_empty());
656 let serialized = serde_json::to_string(&inline).unwrap();
657 assert_eq!(serialized, "[]");
658 }
659
660 #[test]
667 fn test_parse_pr_reference_delegates_to_shared() {
668 let (owner, repo, number) =
669 parse_pr_reference("https://github.com/block/goose/pull/123", None).unwrap();
670 assert_eq!(owner, "block");
671 assert_eq!(repo, "goose");
672 assert_eq!(number, 123);
673 }
674
675 #[test]
676 fn test_title_prefix_to_label_mapping() {
677 let cases = vec![
678 (
679 "feat: add new feature",
680 vec!["enhancement"],
681 "feat should map to enhancement",
682 ),
683 ("fix: resolve bug", vec!["bug"], "fix should map to bug"),
684 (
685 "docs: update readme",
686 vec!["documentation"],
687 "docs should map to documentation",
688 ),
689 (
690 "refactor: improve code",
691 vec!["refactor"],
692 "refactor should map to refactor",
693 ),
694 (
695 "perf: optimize",
696 vec!["enhancement"],
697 "perf should map to enhancement",
698 ),
699 (
700 "chore: update deps",
701 vec![],
702 "chore should produce no labels",
703 ),
704 ];
705
706 for (title, expected_labels, msg) in cases {
707 let labels = labels_from_pr_metadata(title, &[]);
708 for expected in &expected_labels {
709 assert!(
710 labels.contains(&expected.to_string()),
711 "{msg}: expected '{expected}' in {labels:?}",
712 );
713 }
714 if expected_labels.is_empty() {
715 assert!(labels.is_empty(), "{msg}: expected empty, got {labels:?}",);
716 }
717 }
718 }
719
720 #[test]
721 fn test_file_path_to_scope_mapping() {
722 let cases = vec![
723 (
724 "feat: cli",
725 vec!["crates/aptu-cli/src/main.rs"],
726 vec!["enhancement", "cli"],
727 "cli path should map to cli scope",
728 ),
729 (
730 "feat: ios",
731 vec!["crates/aptu-ffi/src/lib.rs"],
732 vec!["enhancement", "ios"],
733 "ffi path should map to ios scope",
734 ),
735 (
736 "feat: ios",
737 vec!["AptuApp/ContentView.swift"],
738 vec!["enhancement", "ios"],
739 "app path should map to ios scope",
740 ),
741 (
742 "feat: docs",
743 vec!["docs/GITHUB_ACTION.md"],
744 vec!["enhancement", "documentation"],
745 "docs path should map to documentation scope",
746 ),
747 (
748 "feat: snap",
749 vec!["snap/snapcraft.yaml"],
750 vec!["enhancement", "distribution"],
751 "snap path should map to distribution scope",
752 ),
753 (
754 "feat: workflow",
755 vec![".github/workflows/test.yml"],
756 vec!["enhancement"],
757 "workflow path should be ignored",
758 ),
759 ];
760
761 for (title, paths, expected_labels, msg) in cases {
762 let labels = labels_from_pr_metadata(
763 title,
764 &paths
765 .iter()
766 .map(std::string::ToString::to_string)
767 .collect::<Vec<_>>(),
768 );
769 for expected in expected_labels {
770 assert!(
771 labels.contains(&expected.to_string()),
772 "{msg}: expected '{expected}' in {labels:?}",
773 );
774 }
775 }
776 }
777
778 #[test]
779 fn test_combined_title_and_paths() {
780 let labels = labels_from_pr_metadata(
781 "feat: multi",
782 &[
783 "crates/aptu-cli/src/main.rs".to_string(),
784 "docs/README.md".to_string(),
785 ],
786 );
787 assert!(
788 labels.contains(&"enhancement".to_string()),
789 "should include enhancement from feat prefix"
790 );
791 assert!(
792 labels.contains(&"cli".to_string()),
793 "should include cli from path"
794 );
795 assert!(
796 labels.contains(&"documentation".to_string()),
797 "should include documentation from path"
798 );
799 }
800
801 #[test]
802 fn test_no_match_returns_empty() {
803 let cases = vec![
804 (
805 "Random title",
806 vec![],
807 "unrecognized prefix should return empty",
808 ),
809 (
810 "chore: update",
811 vec![],
812 "ignored prefix should return empty",
813 ),
814 ];
815
816 for (title, paths, msg) in cases {
817 let labels = labels_from_pr_metadata(title, &paths);
818 assert!(labels.is_empty(), "{msg}: got {labels:?}");
819 }
820 }
821
822 #[test]
823 fn test_scoped_prefix_extracts_type() {
824 let labels = labels_from_pr_metadata("feat(cli): add new feature", &[]);
825 assert!(
826 labels.contains(&"enhancement".to_string()),
827 "scoped prefix should extract type from feat(cli)"
828 );
829 }
830
831 #[test]
832 fn test_duplicate_labels_deduplicated() {
833 let labels = labels_from_pr_metadata("docs: update", &["docs/README.md".to_string()]);
834 assert_eq!(
835 labels.len(),
836 1,
837 "should have exactly one label when title and path both map to documentation"
838 );
839 assert!(
840 labels.contains(&"documentation".to_string()),
841 "should contain documentation label"
842 );
843 }
844
845 #[test]
846 fn test_should_skip_file_respects_fetched_count_cap() {
847 let removed_file = PrFile {
850 filename: "removed.rs".to_string(),
851 status: "removed".to_string(),
852 additions: 0,
853 deletions: 5,
854 patch: None,
855 full_content: None,
856 };
857 let modified_file = PrFile {
858 filename: "file_0.rs".to_string(),
859 status: "modified".to_string(),
860 additions: 1,
861 deletions: 0,
862 patch: Some("+ new code".to_string()),
863 full_content: None,
864 };
865 let no_patch_file = PrFile {
866 filename: "file_1.rs".to_string(),
867 status: "modified".to_string(),
868 additions: 1,
869 deletions: 0,
870 patch: None,
871 full_content: None,
872 };
873
874 assert!(
876 should_skip_file(&removed_file.status, removed_file.patch.as_ref()),
877 "removed files should be skipped"
878 );
879
880 assert!(
882 !should_skip_file(&modified_file.status, modified_file.patch.as_ref()),
883 "modified files with patch should not be skipped"
884 );
885
886 assert!(
888 should_skip_file(&no_patch_file.status, no_patch_file.patch.as_ref()),
889 "files without patch should be skipped"
890 );
891 }
892
893 #[test]
894 fn test_decode_content_valid_base64() {
895 use base64::Engine;
897 let engine = base64::engine::general_purpose::STANDARD;
898 let original = "Hello, World!";
899 let encoded = engine.encode(original);
900
901 let result = decode_content(&encoded, 1000);
903
904 assert_eq!(
906 result,
907 Some(original.to_string()),
908 "valid base64 should decode successfully"
909 );
910 }
911
912 #[test]
913 fn test_decode_content_invalid_base64() {
914 let invalid_base64 = "!!!invalid!!!";
916
917 let result = decode_content(invalid_base64, 1000);
919
920 assert_eq!(result, None, "invalid base64 should return None");
922 }
923
924 #[test]
925 fn test_decode_content_truncates_at_max_chars() {
926 use base64::Engine;
928 let engine = base64::engine::general_purpose::STANDARD;
929 let original = "こんにちは".repeat(10); let encoded = engine.encode(&original);
931 let max_chars = 10;
932
933 let result = decode_content(&encoded, max_chars);
935
936 assert!(result.is_some(), "decoding should succeed");
938 let decoded = result.unwrap();
939 assert_eq!(
940 decoded.chars().count(),
941 max_chars,
942 "output should be truncated to max_chars on character boundary"
943 );
944 assert!(
945 decoded.is_char_boundary(decoded.len()),
946 "output should be valid UTF-8 (truncated on char boundary)"
947 );
948 }
949}