1use anyhow::{Context, Result, bail};
9use octocrab::Octocrab;
10use tracing::{debug, instrument};
11
12use crate::ai::types::{PrDetails, PrFile, ReviewEvent};
13
14pub fn parse_pr_reference(
34 reference: &str,
35 repo_context: Option<&str>,
36) -> Result<(String, String, u64)> {
37 let reference = reference.trim();
38
39 if reference.starts_with("https://github.com/") || reference.starts_with("http://github.com/") {
42 let path = reference
43 .trim_start_matches("https://github.com/")
44 .trim_start_matches("http://github.com/");
45
46 let parts: Vec<&str> = path.split('/').collect();
47 if parts.len() >= 4 && parts[2] == "pull" {
48 let owner = parts[0].to_string();
49 let repo = parts[1].to_string();
50 let number: u64 = parts[3]
51 .parse()
52 .with_context(|| format!("Invalid PR number in URL: {}", parts[3]))?;
53 return Ok((owner, repo, number));
54 }
55 bail!("Invalid GitHub PR URL format: {reference}");
56 }
57
58 if let Some((repo_part, num_part)) = reference.split_once('#') {
60 if let Some((owner, repo)) = repo_part.split_once('/') {
61 let number: u64 = num_part
62 .parse()
63 .with_context(|| format!("Invalid PR number: {num_part}"))?;
64 return Ok((owner.to_string(), repo.to_string(), number));
65 }
66 if let Some(ctx) = repo_context
68 && let Some((owner, repo)) = ctx.split_once('/')
69 {
70 let number: u64 = num_part
71 .parse()
72 .with_context(|| format!("Invalid PR number: {num_part}"))?;
73 return Ok((owner.to_string(), repo.to_string(), number));
74 }
75 bail!("Invalid PR reference format: {reference}");
76 }
77
78 if let Ok(number) = reference.parse::<u64>() {
80 if let Some(ctx) = repo_context {
81 if let Some((owner, repo)) = ctx.split_once('/') {
82 return Ok((owner.to_string(), repo.to_string(), number));
83 }
84 bail!("Invalid repo_context format, expected 'owner/repo': {ctx}");
85 }
86 bail!("Bare PR number requires --repo flag or default_repo config: {reference}");
87 }
88
89 bail!(
90 "Invalid PR reference format: {reference}. Expected URL, owner/repo#number, or number with --repo"
91 )
92}
93
94#[instrument(skip(client), fields(owner = %owner, repo = %repo, number = number))]
113pub async fn fetch_pr_details(
114 client: &Octocrab,
115 owner: &str,
116 repo: &str,
117 number: u64,
118) -> Result<PrDetails> {
119 debug!("Fetching PR details");
120
121 let pr = client
123 .pulls(owner, repo)
124 .get(number)
125 .await
126 .with_context(|| format!("Failed to fetch PR #{number} from {owner}/{repo}"))?;
127
128 let files = client
130 .pulls(owner, repo)
131 .list_files(number)
132 .await
133 .with_context(|| format!("Failed to fetch files for PR #{number}"))?;
134
135 let pr_files: Vec<PrFile> = files
137 .items
138 .into_iter()
139 .map(|f| PrFile {
140 filename: f.filename,
141 status: format!("{:?}", f.status),
142 additions: f.additions,
143 deletions: f.deletions,
144 patch: f.patch,
145 })
146 .collect();
147
148 let details = PrDetails {
149 owner: owner.to_string(),
150 repo: repo.to_string(),
151 number,
152 title: pr.title.unwrap_or_default(),
153 body: pr.body.unwrap_or_default(),
154 base_branch: pr.base.ref_field,
155 head_branch: pr.head.ref_field,
156 files: pr_files,
157 url: pr.html_url.map_or_else(String::new, |u| u.to_string()),
158 };
159
160 debug!(
161 file_count = details.files.len(),
162 "PR details fetched successfully"
163 );
164
165 Ok(details)
166}
167
168#[instrument(skip(client), fields(owner = %owner, repo = %repo, number = number, event = %event))]
190pub async fn post_pr_review(
191 client: &Octocrab,
192 owner: &str,
193 repo: &str,
194 number: u64,
195 body: &str,
196 event: ReviewEvent,
197) -> Result<u64> {
198 debug!("Posting PR review");
199
200 let route = format!("repos/{owner}/{repo}/pulls/{number}/reviews");
201
202 let payload = serde_json::json!({
203 "body": body,
204 "event": event.to_string(),
205 });
206
207 #[derive(serde::Deserialize)]
208 struct ReviewResponse {
209 id: u64,
210 }
211
212 let response: ReviewResponse = client.post(route, Some(&payload)).await.with_context(|| {
213 format!(
214 "Failed to post review to PR #{number} in {owner}/{repo}. \
215 Check that you have write access to the repository."
216 )
217 })?;
218
219 debug!(review_id = response.id, "PR review posted successfully");
220
221 Ok(response.id)
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_pr_reference_full_url() {
        let (owner, repo, number) =
            parse_pr_reference("https://github.com/block/goose/pull/123", None).unwrap();
        assert_eq!(owner, "block");
        assert_eq!(repo, "goose");
        assert_eq!(number, 123);
    }

    #[test]
    fn test_parse_pr_reference_url_with_trailing_segment() {
        let (owner, repo, number) =
            parse_pr_reference("https://github.com/block/goose/pull/123/files", None).unwrap();
        assert_eq!(owner, "block");
        assert_eq!(repo, "goose");
        assert_eq!(number, 123);
    }

    #[test]
    fn test_parse_pr_reference_http_url_and_whitespace() {
        let (owner, repo, number) =
            parse_pr_reference("  http://github.com/block/goose/pull/7\n", None).unwrap();
        assert_eq!(owner, "block");
        assert_eq!(repo, "goose");
        assert_eq!(number, 7);
    }

    #[test]
    fn test_parse_pr_reference_short_form() {
        let (owner, repo, number) = parse_pr_reference("block/goose#456", None).unwrap();
        assert_eq!(owner, "block");
        assert_eq!(repo, "goose");
        assert_eq!(number, 456);
    }

    #[test]
    fn test_parse_pr_reference_bare_number_with_context() {
        let (owner, repo, number) = parse_pr_reference("789", Some("block/goose")).unwrap();
        assert_eq!(owner, "block");
        assert_eq!(repo, "goose");
        assert_eq!(number, 789);
    }

    #[test]
    fn test_parse_pr_reference_bare_number_without_context() {
        let result = parse_pr_reference("123", None);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("requires --repo flag")
        );
    }

    #[test]
    fn test_parse_pr_reference_bare_number_with_invalid_context() {
        let err = parse_pr_reference("123", Some("goose")).unwrap_err();
        assert!(err.to_string().contains("expected 'owner/repo'"));
    }

    #[test]
    fn test_parse_pr_reference_hash_with_context() {
        let (owner, repo, number) = parse_pr_reference("#42", Some("owner/repo")).unwrap();
        assert_eq!(owner, "owner");
        assert_eq!(repo, "repo");
        assert_eq!(number, 42);
    }

    #[test]
    fn test_parse_pr_reference_hash_without_context() {
        assert!(parse_pr_reference("#42", None).is_err());
    }

    #[test]
    fn test_parse_pr_reference_invalid_url() {
        let result = parse_pr_reference("https://github.com/invalid", None);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_pr_reference_non_pull_url() {
        assert!(parse_pr_reference("https://github.com/block/goose/issues/7", None).is_err());
        assert!(parse_pr_reference("https://github.com/block/goose/pull/abc", None).is_err());
    }

    #[test]
    fn test_parse_pr_reference_invalid_number() {
        let result = parse_pr_reference("block/goose#abc", None);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_pr_reference_unrecognized_format() {
        let err = parse_pr_reference("not a reference", None).unwrap_err();
        assert!(err.to_string().contains("Expected URL"));
    }
}