Skip to main content

mockforge_intelligence/pr_generation/
github.rs

1//! GitHub PR client
2//!
3//! This module provides functionality for creating pull requests on GitHub.
4
5use crate::pr_generation::types::{PRFileChange, PRFileChangeType, PRRequest, PRResult};
6use mockforge_foundation::{Error, Result};
7use reqwest::Client;
8
9/// GitHub PR client
10#[derive(Debug, Clone)]
11pub struct GitHubPRClient {
12    owner: String,
13    repo: String,
14    token: String,
15    base_branch: String,
16    client: Client,
17}
18
19impl GitHubPRClient {
20    /// Create a new GitHub PR client
21    pub fn new(owner: String, repo: String, token: String, base_branch: String) -> Self {
22        Self {
23            owner,
24            repo,
25            token,
26            base_branch,
27            client: Client::new(),
28        }
29    }
30
31    /// Create a pull request
32    pub async fn create_pr(&self, request: PRRequest) -> Result<PRResult> {
33        // Step 1: Get base branch SHA
34        let base_sha = self.get_branch_sha(&self.base_branch).await?;
35
36        // Step 2: Create new branch
37        self.create_branch(&request.branch, &base_sha).await?;
38
39        // Step 3: Create commits for file changes
40        let mut current_sha = base_sha;
41        for file_change in &request.files {
42            current_sha = match file_change.change_type {
43                PRFileChangeType::Create | PRFileChangeType::Update => {
44                    self.create_file_commit(&request.branch, file_change, &current_sha).await?
45                }
46                PRFileChangeType::Delete => {
47                    self.delete_file_commit(&request.branch, file_change, &current_sha).await?
48                }
49            };
50        }
51
52        // Step 4: Create pull request
53        let pr = self.create_pull_request(&request, &current_sha).await?;
54
55        // Step 5: Add labels if any
56        if !request.labels.is_empty() {
57            self.add_labels(pr.number, &request.labels).await?;
58        }
59
60        // Step 6: Request reviewers if any
61        if !request.reviewers.is_empty() {
62            self.request_reviewers(pr.number, &request.reviewers).await?;
63        }
64
65        Ok(pr)
66    }
67
68    async fn get_branch_sha(&self, branch: &str) -> Result<String> {
69        let url = format!(
70            "https://api.github.com/repos/{}/{}/git/ref/heads/{}",
71            self.owner, self.repo, branch
72        );
73
74        let response = self
75            .client
76            .get(&url)
77            .header("Authorization", format!("Bearer {}", self.token))
78            .header("Accept", "application/vnd.github.v3+json")
79            .send()
80            .await
81            .map_err(|e| Error::internal(format!("Failed to get branch: {}", e)))?;
82
83        if !response.status().is_success() {
84            return Err(Error::internal(format!("Failed to get branch: {}", response.status())));
85        }
86
87        let json: serde_json::Value = response
88            .json()
89            .await
90            .map_err(|e| Error::internal(format!("Failed to parse response: {}", e)))?;
91
92        json["object"]["sha"]
93            .as_str()
94            .ok_or_else(|| Error::internal("Missing SHA in response"))?
95            .to_string()
96            .pipe(Ok)
97    }
98
99    async fn create_branch(&self, branch: &str, sha: &str) -> Result<()> {
100        let url = format!("https://api.github.com/repos/{}/{}/git/refs", self.owner, self.repo);
101
102        let body = serde_json::json!({
103            "ref": format!("refs/heads/{}", branch),
104            "sha": sha
105        });
106
107        let response = self
108            .client
109            .post(&url)
110            .header("Authorization", format!("Bearer {}", self.token))
111            .header("Accept", "application/vnd.github.v3+json")
112            .json(&body)
113            .send()
114            .await
115            .map_err(|e| Error::internal(format!("Failed to create branch: {}", e)))?;
116
117        let status = response.status();
118        if !status.is_success() {
119            let error_text = response.text().await.unwrap_or_default();
120            return Err(Error::internal(format!(
121                "Failed to create branch: {} - {}",
122                status, error_text
123            )));
124        }
125
126        Ok(())
127    }
128
129    async fn create_file_commit(
130        &self,
131        branch: &str,
132        file_change: &PRFileChange,
133        parent_sha: &str,
134    ) -> Result<String> {
135        // First, create blob with file content
136        let blob_sha = self.create_blob(&file_change.content).await?;
137
138        // Then, create tree with the new file
139        let tree_sha = self.create_tree(parent_sha, &file_change.path, &blob_sha, "100644").await?;
140
141        // Finally, create commit
142        let commit_sha = self
143            .create_commit(parent_sha, &tree_sha, &format!("Update {}", file_change.path))
144            .await?;
145
146        // Update branch reference
147        self.update_branch_ref(branch, &commit_sha).await?;
148
149        Ok(commit_sha)
150    }
151
152    async fn delete_file_commit(
153        &self,
154        branch: &str,
155        file_change: &PRFileChange,
156        parent_sha: &str,
157    ) -> Result<String> {
158        // Create tree without the file
159        let tree_sha = self.create_tree_delete(parent_sha, &file_change.path).await?;
160
161        // Create commit
162        let commit_sha = self
163            .create_commit(parent_sha, &tree_sha, &format!("Delete {}", file_change.path))
164            .await?;
165
166        // Update branch reference
167        self.update_branch_ref(branch, &commit_sha).await?;
168
169        Ok(commit_sha)
170    }
171
172    async fn create_blob(&self, content: &str) -> Result<String> {
173        let url = format!("https://api.github.com/repos/{}/{}/git/blobs", self.owner, self.repo);
174
175        let body = serde_json::json!({
176            "content": content,
177            "encoding": "utf-8"
178        });
179
180        let response = self
181            .client
182            .post(&url)
183            .header("Authorization", format!("Bearer {}", self.token))
184            .header("Accept", "application/vnd.github.v3+json")
185            .json(&body)
186            .send()
187            .await
188            .map_err(|e| Error::internal(format!("Failed to create blob: {}", e)))?;
189
190        if !response.status().is_success() {
191            return Err(Error::internal(format!("Failed to create blob: {}", response.status())));
192        }
193
194        let json: serde_json::Value = response
195            .json()
196            .await
197            .map_err(|e| Error::internal(format!("Failed to parse response: {}", e)))?;
198
199        json["sha"]
200            .as_str()
201            .ok_or_else(|| Error::internal("Missing SHA in response"))?
202            .to_string()
203            .pipe(Ok)
204    }
205
206    async fn create_tree(
207        &self,
208        base_tree_sha: &str,
209        path: &str,
210        blob_sha: &str,
211        mode: &str,
212    ) -> Result<String> {
213        let url = format!("https://api.github.com/repos/{}/{}/git/trees", self.owner, self.repo);
214
215        let body = serde_json::json!({
216            "base_tree": base_tree_sha,
217            "tree": [{
218                "path": path,
219                "mode": mode,
220                "type": "blob",
221                "sha": blob_sha
222            }]
223        });
224
225        let response = self
226            .client
227            .post(&url)
228            .header("Authorization", format!("Bearer {}", self.token))
229            .header("Accept", "application/vnd.github.v3+json")
230            .json(&body)
231            .send()
232            .await
233            .map_err(|e| Error::internal(format!("Failed to create tree: {}", e)))?;
234
235        if !response.status().is_success() {
236            return Err(Error::internal(format!("Failed to create tree: {}", response.status())));
237        }
238
239        let json: serde_json::Value = response
240            .json()
241            .await
242            .map_err(|e| Error::internal(format!("Failed to parse response: {}", e)))?;
243
244        json["sha"]
245            .as_str()
246            .ok_or_else(|| Error::internal("Missing SHA in response"))?
247            .to_string()
248            .pipe(Ok)
249    }
250
251    async fn create_tree_delete(&self, base_tree_sha: &str, path: &str) -> Result<String> {
252        let url = format!("https://api.github.com/repos/{}/{}/git/trees", self.owner, self.repo);
253
254        let body = serde_json::json!({
255            "base_tree": base_tree_sha,
256            "tree": [{
257                "path": path,
258                "mode": "100644",
259                "type": "blob",
260                "sha": null
261            }]
262        });
263
264        let response = self
265            .client
266            .post(&url)
267            .header("Authorization", format!("Bearer {}", self.token))
268            .header("Accept", "application/vnd.github.v3+json")
269            .json(&body)
270            .send()
271            .await
272            .map_err(|e| Error::internal(format!("Failed to create tree: {}", e)))?;
273
274        if !response.status().is_success() {
275            return Err(Error::internal(format!("Failed to create tree: {}", response.status())));
276        }
277
278        let json: serde_json::Value = response
279            .json()
280            .await
281            .map_err(|e| Error::internal(format!("Failed to parse response: {}", e)))?;
282
283        json["sha"]
284            .as_str()
285            .ok_or_else(|| Error::internal("Missing SHA in response"))?
286            .to_string()
287            .pipe(Ok)
288    }
289
290    async fn create_commit(
291        &self,
292        parent_sha: &str,
293        tree_sha: &str,
294        message: &str,
295    ) -> Result<String> {
296        let url = format!("https://api.github.com/repos/{}/{}/git/commits", self.owner, self.repo);
297
298        let body = serde_json::json!({
299            "message": message,
300            "tree": tree_sha,
301            "parents": [parent_sha]
302        });
303
304        let response = self
305            .client
306            .post(&url)
307            .header("Authorization", format!("Bearer {}", self.token))
308            .header("Accept", "application/vnd.github.v3+json")
309            .json(&body)
310            .send()
311            .await
312            .map_err(|e| Error::internal(format!("Failed to create commit: {}", e)))?;
313
314        if !response.status().is_success() {
315            return Err(Error::internal(format!("Failed to create commit: {}", response.status())));
316        }
317
318        let json: serde_json::Value = response
319            .json()
320            .await
321            .map_err(|e| Error::internal(format!("Failed to parse response: {}", e)))?;
322
323        json["sha"]
324            .as_str()
325            .ok_or_else(|| Error::internal("Missing SHA in response"))?
326            .to_string()
327            .pipe(Ok)
328    }
329
330    async fn update_branch_ref(&self, branch: &str, sha: &str) -> Result<()> {
331        let url = format!(
332            "https://api.github.com/repos/{}/{}/git/refs/heads/{}",
333            self.owner, self.repo, branch
334        );
335
336        let body = serde_json::json!({
337            "sha": sha,
338            "force": false
339        });
340
341        let response = self
342            .client
343            .patch(&url)
344            .header("Authorization", format!("Bearer {}", self.token))
345            .header("Accept", "application/vnd.github.v3+json")
346            .json(&body)
347            .send()
348            .await
349            .map_err(|e| Error::internal(format!("Failed to update branch: {}", e)))?;
350
351        if !response.status().is_success() {
352            return Err(Error::internal(format!("Failed to update branch: {}", response.status())));
353        }
354
355        Ok(())
356    }
357
358    async fn create_pull_request(&self, request: &PRRequest, _head_sha: &str) -> Result<PRResult> {
359        let url = format!("https://api.github.com/repos/{}/{}/pulls", self.owner, self.repo);
360
361        let body = serde_json::json!({
362            "title": request.title,
363            "body": request.body,
364            "head": request.branch,
365            "base": self.base_branch
366        });
367
368        let response = self
369            .client
370            .post(&url)
371            .header("Authorization", format!("Bearer {}", self.token))
372            .header("Accept", "application/vnd.github.v3+json")
373            .json(&body)
374            .send()
375            .await
376            .map_err(|e| Error::internal(format!("Failed to create PR: {}", e)))?;
377
378        let status = response.status();
379        if !status.is_success() {
380            let error_text = response.text().await.unwrap_or_default();
381            return Err(Error::internal(format!(
382                "Failed to create PR: {} - {}",
383                status, error_text
384            )));
385        }
386
387        let json: serde_json::Value = response
388            .json()
389            .await
390            .map_err(|e| Error::internal(format!("Failed to parse response: {}", e)))?;
391
392        Ok(PRResult {
393            number: json["number"].as_u64().ok_or_else(|| Error::internal("Missing PR number"))?,
394            url: json["html_url"]
395                .as_str()
396                .ok_or_else(|| Error::internal("Missing PR URL"))?
397                .to_string(),
398            branch: request.branch.clone(),
399            title: request.title.clone(),
400        })
401    }
402
403    async fn add_labels(&self, pr_number: u64, labels: &[String]) -> Result<()> {
404        let url = format!(
405            "https://api.github.com/repos/{}/{}/issues/{}/labels",
406            self.owner, self.repo, pr_number
407        );
408
409        let body = serde_json::json!({
410            "labels": labels
411        });
412
413        let response = self
414            .client
415            .post(&url)
416            .header("Authorization", format!("Bearer {}", self.token))
417            .header("Accept", "application/vnd.github.v3+json")
418            .json(&body)
419            .send()
420            .await
421            .map_err(|e| Error::internal(format!("Failed to add labels: {}", e)))?;
422
423        if !response.status().is_success() {
424            return Err(Error::internal(format!("Failed to add labels: {}", response.status())));
425        }
426
427        Ok(())
428    }
429
430    async fn request_reviewers(&self, pr_number: u64, reviewers: &[String]) -> Result<()> {
431        let url = format!(
432            "https://api.github.com/repos/{}/{}/pulls/{}/requested_reviewers",
433            self.owner, self.repo, pr_number
434        );
435
436        let body = serde_json::json!({
437            "reviewers": reviewers
438        });
439
440        let response = self
441            .client
442            .post(&url)
443            .header("Authorization", format!("Bearer {}", self.token))
444            .header("Accept", "application/vnd.github.v3+json")
445            .json(&body)
446            .send()
447            .await
448            .map_err(|e| Error::internal(format!("Failed to request reviewers: {}", e)))?;
449
450        if !response.status().is_success() {
451            return Err(Error::internal(format!(
452                "Failed to request reviewers: {}",
453                response.status()
454            )));
455        }
456
457        Ok(())
458    }
459}
460
461// Helper trait for pipe operator
462trait Pipe: Sized {
463    fn pipe<F, R>(self, f: F) -> R
464    where
465        F: FnOnce(Self) -> R,
466    {
467        f(self)
468    }
469}
470
471impl<T> Pipe for T {}