1use serde::{Deserialize, Serialize};
2
3use crate::error::{Error, Result};
4
5#[derive(Debug, Serialize, Deserialize)]
6pub struct PullRequest {
7 pub number: u32,
8 pub title: String,
9 pub state: String,
10 pub html_url: String,
11 pub draft: bool,
12}
13
14#[derive(Debug, Deserialize)]
16struct GhPrResponse {
17 number: u32,
18 title: String,
19 state: String,
20 url: String,
21 #[serde(rename = "isDraft")]
22 is_draft: bool,
23}
24
25#[derive(Debug, Deserialize)]
26struct GhPrWithBranchResponse {
27 number: u32,
28 title: String,
29 state: String,
30 url: String,
31 #[serde(rename = "isDraft")]
32 is_draft: bool,
33 #[serde(rename = "headRefName")]
34 head_ref_name: String,
35}
36
/// Stateless handle for querying GitHub through the locally installed `gh` CLI.
#[derive(Debug, Clone)]
pub struct GitHubClient;
38
39impl Default for GitHubClient {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl GitHubClient {
46 pub fn new() -> Self {
47 Self
48 }
49
50 fn get_gh_token() -> Option<String> {
51 std::process::Command::new("gh")
52 .args(["auth", "token"])
53 .output()
54 .ok()
55 .and_then(|output| {
56 if output.status.success() {
57 String::from_utf8(output.stdout)
58 .ok()
59 .map(|s| s.trim().to_string())
60 .filter(|s| !s.is_empty())
61 } else {
62 None
63 }
64 })
65 }
66
67 pub fn has_auth(&self) -> bool {
68 Self::get_gh_token().is_some()
69 }
70
71 pub fn get_pull_requests(&self, owner: &str, repo: &str, branch: &str) -> Result<Vec<PullRequest>> {
72 let output = std::process::Command::new("gh")
74 .args([
75 "pr",
76 "list",
77 "--repo",
78 &format!("{}/{}", owner, repo),
79 "--head",
80 branch,
81 "--state",
82 "all",
83 "--json",
84 "number,title,state,url,isDraft",
85 ])
86 .output()
87 .map_err(|e| Error::provider(format!("Failed to execute gh command: {}", e)))?;
88
89 if !output.status.success() {
90 let stderr = String::from_utf8_lossy(&output.stderr);
91 if stderr.contains("not authenticated") || stderr.contains("authentication") {
92 return Err(Error::auth(
93 "GitHub authentication failed. Run 'gh auth login' to authenticate.",
94 ));
95 }
96 return Err(Error::provider(format!("Failed to fetch pull requests: {}", stderr)));
97 }
98
99 let stdout = String::from_utf8(output.stdout)?;
100 if stdout.trim().is_empty() {
101 return Ok(vec![]);
102 }
103
104 let prs: Vec<GhPrResponse> = serde_json::from_str(&stdout)
105 .map_err(|e| Error::provider(format!("Failed to parse pull requests from gh output: {}", e)))?;
106
107 Ok(prs
108 .into_iter()
109 .map(|pr| PullRequest {
110 number: pr.number,
111 title: pr.title,
112 state: pr.state,
113 html_url: pr.url,
114 draft: pr.is_draft,
115 })
116 .collect())
117 }
118
119 pub fn get_all_pull_requests(&self, owner: &str, repo: &str) -> Result<Vec<(PullRequest, String)>> {
120 let output = std::process::Command::new("gh")
122 .args([
123 "pr",
124 "list",
125 "--repo",
126 &format!("{}/{}", owner, repo),
127 "--state",
128 "open",
129 "--json",
130 "number,title,state,url,isDraft,headRefName",
131 "--limit",
132 "100",
133 ])
134 .output()
135 .map_err(|e| Error::provider(format!("Failed to execute gh command: {}", e)))?;
136
137 if !output.status.success() {
138 let stderr = String::from_utf8_lossy(&output.stderr);
139 if stderr.contains("not authenticated") || stderr.contains("authentication") {
140 return Err(Error::auth(
141 "GitHub authentication failed. Run 'gh auth login' to authenticate.",
142 ));
143 }
144 return Err(Error::provider(format!("Failed to fetch pull requests: {}", stderr)));
145 }
146
147 let stdout = String::from_utf8(output.stdout)?;
148 if stdout.trim().is_empty() {
149 return Ok(vec![]);
150 }
151
152 let prs: Vec<GhPrWithBranchResponse> = serde_json::from_str(&stdout)
153 .map_err(|e| Error::provider(format!("Failed to parse pull requests from gh output: {}", e)))?;
154
155 Ok(prs
156 .into_iter()
157 .map(|pr| {
158 let pull_request = PullRequest {
159 number: pr.number,
160 title: pr.title,
161 state: pr.state,
162 html_url: pr.url,
163 draft: pr.is_draft,
164 };
165 (pull_request, pr.head_ref_name)
166 })
167 .collect())
168 }
169
170 pub fn parse_github_url(url: &str) -> Option<(String, String)> {
171 if let Some(captures) = url.strip_prefix("https://github.com/") {
173 let parts: Vec<&str> = captures.trim_end_matches(".git").split('/').collect();
174 if parts.len() >= 2 {
175 return Some((parts[0].to_string(), parts[1].to_string()));
176 }
177 } else if let Some(captures) = url.strip_prefix("git@github.com:") {
178 let parts: Vec<&str> = captures.trim_end_matches(".git").split('/').collect();
179 if parts.len() >= 2 {
180 return Some((parts[0].to_string(), parts[1].to_string()));
181 }
182 }
183 None
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    /// Remote URL shapes that must resolve to `("owner", "repo")`.
    #[test]
    fn test_parse_github_url() {
        let accepted = [
            "https://github.com/owner/repo.git",
            "https://github.com/owner/repo",
            "https://github.com/owner/repo/",
            "https://github.com/owner/repo/pulls",
            "git@github.com:owner/repo.git",
            "git@github.com:owner/repo",
        ];

        for url in accepted {
            assert_eq!(
                GitHubClient::parse_github_url(url),
                Some(("owner".to_string(), "repo".to_string())),
                "url: {}",
                url
            );
        }
    }

    /// Non-GitHub hosts and URLs without an owner/repo pair must be rejected.
    #[test]
    fn test_parse_github_url_rejects_non_github_or_incomplete() {
        let rejected = [
            "https://gitlab.com/owner/repo",
            "git@gitlab.com:owner/repo.git",
            "https://github.com/owner",
            "git@github.com:owner",
            "",
        ];

        for url in rejected {
            assert_eq!(GitHubClient::parse_github_url(url), None, "url: {}", url);
        }
    }
}