1use async_trait::async_trait;
11use oxi_sdk::SdkError;
12use oxi_sdk::ports::{ProtocolHandler, ResolveContext, ResolvedUrl};
13use serde::Deserialize;
14
15use super::{detect_github_repo, github_token};
16use crate::util::http_client::shared_http_client;
17
/// Protocol handler for the `pr` URL scheme.
///
/// Resolving a URL fetches the pull request from the GitHub REST API and renders it as
/// Markdown (see the `ProtocolHandler` implementation in this file). The type has no fields.
#[derive(Debug, Clone, Default)]
pub struct PrProtocolHandler;
21
/// A parsed `pr` URL: which pull request to fetch and whether its diff was requested.
struct PrUrl {
    /// Owner segment of the `owner/repo` pair.
    owner: String,
    /// Repository name segment of the `owner/repo` pair.
    repo: String,
    /// Pull request number.
    pr_number: u64,
    /// `true` when the URL carried a `diff` segment, asking for the diff to be appended.
    diff: bool,
}
30
/// The fields of a GitHub pull-request API response that this handler reads.
///
/// Only `number`, `title` and `state` are required; every other field is an `Option`.
#[derive(Debug, Deserialize)]
struct GhPr {
    /// Pull request number, used in the rendered heading.
    number: u64,
    /// Pull request title, used in the rendered heading.
    title: String,
    /// Description text; rendered after a horizontal rule when present and non-empty.
    body: Option<String>,
    /// State string; rendering special-cases `open`, `closed` and `merged`.
    state: String,
    /// Author of the pull request.
    user: Option<GhUser>,
    /// Labels attached to the pull request.
    labels: Option<Vec<GhLabel>>,
    /// Creation timestamp, rendered verbatim.
    created_at: Option<String>,
    /// Merge timestamp, rendered verbatim; takes precedence over `closed_at` when both are set.
    merged_at: Option<String>,
    /// Close timestamp, rendered verbatim when `merged_at` is absent.
    closed_at: Option<String>,
    /// Whether the pull request is a draft.
    draft: Option<bool>,
    /// Source branch reference.
    head: Option<GhRef>,
    /// Target branch reference.
    base: Option<GhRef>,
    /// Whether the pull request is reported as mergeable.
    mergeable: Option<bool>,
}
48
/// The user object nested in a pull-request response; only the login is read.
#[derive(Debug, Deserialize)]
struct GhUser {
    /// Login name, rendered as `@login`.
    login: String,
}
53
/// A label object nested in a pull-request response; only the name is read.
#[derive(Debug, Deserialize)]
struct GhLabel {
    /// Label name as displayed in the rendered label list.
    name: String,
}
58
/// A branch reference (`head` or `base`) nested in a pull-request response.
#[derive(Debug, Deserialize)]
struct GhRef {
    /// Branch name. The JSON key is `ref`, which is a Rust keyword, hence the rename.
    #[serde(rename = "ref")]
    ref_name: String,
}
64
65impl PrProtocolHandler {
66 fn parse_url(url: &str) -> Result<PrUrl, SdkError> {
68 let url = url.trim();
69 if url.is_empty() {
70 return Err(SdkError::Internal(anyhow::anyhow!("empty PR URL")));
71 }
72
73 let parts: Vec<&str> = url.split('/').collect();
74
75 let (core_parts, wants_diff) = if parts.len() >= 5 && parts[parts.len() - 2] == "diff" {
77 (&parts[..parts.len() - 2], true)
79 } else {
80 (parts.as_slice(), false)
81 };
82
83 match core_parts.len() {
84 1 => {
85 let pr_number: u64 = core_parts[0].parse().map_err(|_| {
87 SdkError::Internal(anyhow::anyhow!("invalid PR number: {}", core_parts[0]))
88 })?;
89 let repo = detect_github_repo().ok_or_else(|| {
90 SdkError::Internal(anyhow::anyhow!(
91 "could not detect GitHub repo from git remote; use owner/repo/N format"
92 ))
93 })?;
94 let (owner, repo_name) = split_owner_repo(&repo)?;
95 Ok(PrUrl {
96 owner,
97 repo: repo_name,
98 pr_number,
99 diff: wants_diff,
100 })
101 }
102 2 => {
103 Err(SdkError::Internal(anyhow::anyhow!(
105 "PR URL requires a number: {url} (use owner/repo/N)"
106 )))
107 }
108 3 => {
109 let pr_number: u64 = core_parts[2].parse().map_err(|_| {
111 SdkError::Internal(anyhow::anyhow!("invalid PR number: {}", core_parts[2]))
112 })?;
113 Ok(PrUrl {
114 owner: core_parts[0].to_string(),
115 repo: core_parts[1].to_string(),
116 pr_number,
117 diff: wants_diff,
118 })
119 }
120 _ => Err(SdkError::Internal(anyhow::anyhow!(
121 "invalid PR URL format: {url}"
122 ))),
123 }
124 }
125}
126
127fn split_owner_repo(repo: &str) -> Result<(String, String), SdkError> {
128 let parts: Vec<&str> = repo.split('/').collect();
129 if parts.len() != 2 {
130 return Err(SdkError::Internal(anyhow::anyhow!(
131 "invalid repo format (expected owner/repo): {repo}"
132 )));
133 }
134 Ok((parts[0].to_string(), parts[1].to_string()))
135}
136
137#[async_trait]
138impl ProtocolHandler for PrProtocolHandler {
139 fn scheme(&self) -> &str {
140 "pr"
141 }
142
143 async fn resolve(
144 &self,
145 url: &str,
146 _selector: Option<&str>,
147 _ctx: &ResolveContext,
148 ) -> Result<ResolvedUrl, SdkError> {
149 let parsed = Self::parse_url(url)?;
150
151 let client = shared_http_client();
152 let token = github_token();
153
154 let pr_api_url = format!(
156 "https://api.github.com/repos/{}/{}/pulls/{}",
157 parsed.owner, parsed.repo, parsed.pr_number
158 );
159
160 let mut request = client
161 .get(&pr_api_url)
162 .header("User-Agent", "oxi-cli")
163 .header("Accept", "application/vnd.github.v3+json");
164
165 if let Some(ref t) = token {
166 request = request.header("Authorization", format!("Bearer {}", t));
167 }
168
169 let response = request
170 .send()
171 .await
172 .map_err(|e| SdkError::Internal(anyhow::anyhow!("GitHub API request failed: {e}")))?;
173
174 if !response.status().is_success() {
175 let status = response.status();
176 let body = response.text().await.unwrap_or_default();
177 return Err(SdkError::Internal(anyhow::anyhow!(
178 "GitHub API returned {status}: {body}"
179 )));
180 }
181
182 let pr: GhPr = response.json().await.map_err(|e| {
183 SdkError::Internal(anyhow::anyhow!("failed to parse GitHub API response: {e}"))
184 })?;
185
186 let mut md = format_pr_markdown(&pr);
187
188 if parsed.diff {
190 md.push_str("\n\n## Diff\n\n");
191 match fetch_pr_diff(&client, &parsed, token.as_deref()).await {
192 Ok(diff) => {
193 md.push_str("```diff\n");
194 md.push_str(&diff);
195 md.push_str("\n```\n");
196 }
197 Err(e) => {
198 md.push_str(&format!("*Failed to fetch diff: {e}*\n"));
199 }
200 }
201 }
202
203 Ok(ResolvedUrl {
204 url: format!(
205 "https://github.com/{}/{}/pull/{}",
206 parsed.owner, parsed.repo, parsed.pr_number
207 ),
208 content: md,
209 content_type: "text/markdown".into(),
210 size: None,
211 source_path: None,
212 notes: vec![],
213 immutable: false,
214 })
215 }
216}
217
218async fn fetch_pr_diff(
219 client: &reqwest::Client,
220 pr: &PrUrl,
221 token: Option<&str>,
222) -> Result<String, SdkError> {
223 let diff_url = format!(
224 "https://api.github.com/repos/{}/{}/pulls/{}",
225 pr.owner, pr.repo, pr.pr_number
226 );
227
228 let mut request = client
229 .get(&diff_url)
230 .header("User-Agent", "oxi-cli")
231 .header("Accept", "application/vnd.github.v3.diff");
232
233 if let Some(t) = token {
234 request = request.header("Authorization", format!("Bearer {}", t));
235 }
236
237 let response = request
238 .send()
239 .await
240 .map_err(|e| SdkError::Internal(anyhow::anyhow!("GitHub diff request failed: {e}")))?;
241
242 if !response.status().is_success() {
243 let status = response.status();
244 let body = response.text().await.unwrap_or_default();
245 return Err(SdkError::Internal(anyhow::anyhow!(
246 "GitHub diff API returned {status}: {body}"
247 )));
248 }
249
250 response
251 .text()
252 .await
253 .map_err(|e| SdkError::Internal(anyhow::anyhow!("failed to read diff response: {e}")))
254}
255
256fn format_pr_markdown(pr: &GhPr) -> String {
257 let mut md = format!("# PR #{}: {}\n\n", pr.number, pr.title);
258
259 let state_label = match (pr.state.as_str(), pr.draft) {
261 (_, Some(true)) => "📝 Draft",
262 ("open", _) => "🟢 Open",
263 ("closed", _) => "🔴 Closed",
264 ("merged", _) => "🟣 Merged",
265 (other, _) => other,
266 };
267 md.push_str(&format!("**State:** {}\n\n", state_label));
268
269 if let Some(ref user) = pr.user {
271 md.push_str(&format!("**Author:** @{}\n\n", user.login));
272 }
273
274 if let Some(ref head) = pr.head {
276 if let Some(ref base) = pr.base {
277 md.push_str(&format!(
278 "**Branch:** `{}` → `{}`\n\n",
279 head.ref_name, base.ref_name
280 ));
281 }
282 }
283
284 if let Some(ref labels) = pr.labels {
286 if !labels.is_empty() {
287 let label_names: Vec<&str> = labels.iter().map(|l| l.name.as_str()).collect();
288 md.push_str(&format!("**Labels:** {}\n\n", label_names.join(", ")));
289 }
290 }
291
292 if let Some(mergeable) = pr.mergeable {
294 md.push_str(&format!(
295 "**Mergeable:** {}\n\n",
296 if mergeable { "✅ Yes" } else { "❌ No" }
297 ));
298 }
299
300 if let Some(ref created) = pr.created_at {
302 md.push_str(&format!("**Created:** {}\n", created));
303 }
304 if let Some(ref merged) = pr.merged_at {
305 md.push_str(&format!("**Merged:** {}\n", merged));
306 } else if let Some(ref closed) = pr.closed_at {
307 md.push_str(&format!("**Closed:** {}\n", closed));
308 }
309
310 md.push('\n');
311
312 if let Some(ref body) = pr.body {
314 if !body.is_empty() {
315 md.push_str("---\n\n");
316 md.push_str(body);
317 md.push('\n');
318 }
319 }
320
321 md
322}
323
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_url_n() {
        // The bare-number form relies on `detect_github_repo()`, which reads the git remote
        // of the environment the tests run in. Check the parsed fields when detection
        // succeeds, and do not fail the suite when no GitHub remote is available.
        if let Ok(parsed) = PrProtocolHandler::parse_url("42") {
            assert_eq!(parsed.pr_number, 42);
            assert!(!parsed.diff);
            assert!(!parsed.owner.is_empty());
            assert!(!parsed.repo.is_empty());
        }
    }

    #[test]
    fn test_parse_url_owner_repo_n() {
        let result = PrProtocolHandler::parse_url("rust-lang/rust/12345").unwrap();
        assert_eq!(result.owner, "rust-lang");
        assert_eq!(result.repo, "rust");
        assert_eq!(result.pr_number, 12345);
        assert!(!result.diff);
    }

    #[test]
    fn test_parse_url_owner_repo_n_diff() {
        let result = PrProtocolHandler::parse_url("rust-lang/rust/12345/diff/0").unwrap();
        assert_eq!(result.owner, "rust-lang");
        assert_eq!(result.repo, "rust");
        assert_eq!(result.pr_number, 12345);
        assert!(result.diff);
    }

    #[test]
    fn test_parse_url_rejects_two_parts() {
        let result = PrProtocolHandler::parse_url("owner/repo");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_url_rejects_empty() {
        let result = PrProtocolHandler::parse_url("");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_url_rejects_non_numeric_pr() {
        let result = PrProtocolHandler::parse_url("rust-lang/rust/not-a-number");
        assert!(result.is_err());
    }

    #[test]
    fn test_format_pr_markdown() {
        let pr = GhPr {
            number: 42,
            title: "Add new feature".into(),
            body: Some("Implements the new widget system.".into()),
            state: "open".into(),
            user: Some(GhUser {
                login: "coder".into(),
            }),
            labels: Some(vec![GhLabel {
                name: "enhancement".into(),
            }]),
            created_at: Some("2026-01-15T12:00:00Z".into()),
            merged_at: None,
            closed_at: None,
            draft: Some(false),
            head: Some(GhRef {
                ref_name: "feature/widget".into(),
            }),
            base: Some(GhRef {
                ref_name: "main".into(),
            }),
            mergeable: Some(true),
        };

        let md = format_pr_markdown(&pr);
        assert!(md.contains("# PR #42: Add new feature"));
        assert!(md.contains("🟢 Open"));
        assert!(md.contains("@coder"));
        assert!(md.contains("`feature/widget` → `main`"));
        assert!(md.contains("enhancement"));
        assert!(md.contains("✅ Yes"));
        assert!(md.contains("Implements the new widget system"));
    }

    #[test]
    fn test_format_pr_draft() {
        let pr = GhPr {
            number: 1,
            title: "Draft PR".into(),
            body: None,
            state: "open".into(),
            user: None,
            labels: None,
            created_at: None,
            merged_at: None,
            closed_at: None,
            draft: Some(true),
            head: None,
            base: None,
            mergeable: None,
        };

        let md = format_pr_markdown(&pr);
        assert!(md.contains("📝 Draft"));
    }
}