// git_iris/github.rs

use crate::git::GitRepo;
use anyhow::{Context, Result, anyhow, bail};
use octocrab::models::pulls::{PullRequest, Review, ReviewAction};
use octocrab::{Octocrab, params};
use regex::Regex;
use serde_json::Value;
use std::collections::{HashMap, HashSet};
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::LazyLock;
use url::Url;
12
13static BACKTICK_LOCATION_RE: LazyLock<Regex> = LazyLock::new(|| {
14    Regex::new(r"`([^`\s]+):(\d+)`").expect("backtick location regex should compile")
15});
16static PLAIN_LOCATION_RE: LazyLock<Regex> = LazyLock::new(|| {
17    Regex::new(r"([A-Za-z0-9_./-]+\.[A-Za-z0-9_-]+):(\d+)")
18        .expect("plain location regex should compile")
19});
20static HUNK_RE: LazyLock<Regex> = LazyLock::new(|| {
21    Regex::new(r"^@@ -\d+(?:,\d+)? \+(\d+)(?:,\d+)? @@")
22        .expect("unified diff hunk regex should compile")
23});
24
/// Owner/name pair identifying a repository on GitHub.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GitHubRepository {
    /// First path segment of the remote URL (user or organization).
    pub owner: String,
    /// Second path segment of the remote URL (repository name).
    pub name: String,
}
30
/// A pull request template file that was found on disk.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PullRequestTemplate {
    /// Location of the template, relative to the repository root when possible.
    pub path: String,
    /// Full text of the template file.
    pub body: String,
}
36
37#[derive(Debug, Clone, Copy)]
38pub struct ReviewPublishOptions {
39    pub event: ReviewAction,
40    pub inline_comments: bool,
41}
42
43pub struct GitHubClient {
44    crab: Octocrab,
45    repo: GitHubRepository,
46}
47
48impl GitHubClient {
49    pub fn from_git_repo(repo: &GitRepo) -> Result<Self> {
50        let remote_url = github_remote_url(repo)?;
51        let github_repo = GitHubRepository::parse(&remote_url)?;
52        let token =
53            gh_token::get().map_err(|e| anyhow!("GitHub authentication unavailable: {e}"))?;
54        let crab = Octocrab::builder()
55            .personal_token(token)
56            .build()
57            .context("Failed to initialize GitHub client")?;
58
59        Ok(Self {
60            crab,
61            repo: github_repo,
62        })
63    }
64
65    pub async fn resolve_pull_number(
66        &self,
67        explicit_pull_number: Option<u64>,
68        git_repo: &GitRepo,
69    ) -> Result<u64> {
70        if let Some(number) = explicit_pull_number {
71            return Ok(number);
72        }
73
74        let branch = git_repo
75            .get_current_branch()
76            .context("Could not infer PR: failed to read current branch")?;
77        if branch == "HEAD detached" {
78            bail!("Could not infer PR from a detached HEAD; pass --pr <number>");
79        }
80
81        self.find_open_pull_for_branch(&branch).await
82    }
83
84    pub async fn update_pull_body(&self, pull_number: u64, body: &str) -> Result<PullRequest> {
85        self.crab
86            .pulls(&self.repo.owner, &self.repo.name)
87            .update(pull_number)
88            .body(body)
89            .send()
90            .await
91            .with_context(|| format!("Failed to update PR #{pull_number}"))
92    }
93
94    pub async fn pull_body(&self, pull_number: u64) -> Result<String> {
95        let pull = self
96            .crab
97            .pulls(&self.repo.owner, &self.repo.name)
98            .get(pull_number)
99            .await
100            .with_context(|| format!("Failed to fetch PR #{pull_number}"))?;
101
102        Ok(pull.body.unwrap_or_default())
103    }
104
105    pub async fn publish_review(
106        &self,
107        pull_number: u64,
108        body: &str,
109        options: ReviewPublishOptions,
110    ) -> Result<Review> {
111        let pull = self
112            .crab
113            .pulls(&self.repo.owner, &self.repo.name)
114            .get(pull_number)
115            .await
116            .with_context(|| format!("Failed to fetch PR #{pull_number}"))?;
117        let comments = if options.inline_comments {
118            self.validated_inline_comments(pull_number, body).await?
119        } else {
120            Vec::new()
121        };
122
123        let route = format!(
124            "/repos/{owner}/{repo}/pulls/{pull_number}/reviews",
125            owner = self.repo.owner,
126            repo = self.repo.name,
127        );
128        let payload = serde_json::json!({
129            "body": body,
130            "event": options.event,
131            "commit_id": pull.head.sha,
132            "comments": comments,
133        });
134
135        self.crab
136            .post(route, Some(&payload))
137            .await
138            .with_context(|| format!("Failed to publish review on PR #{pull_number}"))
139    }
140
141    pub fn repo(&self) -> &GitHubRepository {
142        &self.repo
143    }
144
145    async fn find_open_pull_for_branch(&self, branch: &str) -> Result<u64> {
146        let same_repo_head = format!("{}:{branch}", self.repo.owner);
147        let page = self
148            .crab
149            .pulls(&self.repo.owner, &self.repo.name)
150            .list()
151            .state(params::State::Open)
152            .head(same_repo_head)
153            .per_page(10)
154            .send()
155            .await
156            .with_context(|| format!("Failed to search open PRs for branch `{branch}`"))?;
157
158        if let Some(number) = single_pull_number(&page.items) {
159            return Ok(number);
160        }
161
162        let page = self
163            .crab
164            .pulls(&self.repo.owner, &self.repo.name)
165            .list()
166            .state(params::State::Open)
167            .per_page(100)
168            .send()
169            .await
170            .context("Failed to list open PRs")?;
171        let matches: Vec<&PullRequest> = page
172            .items
173            .iter()
174            .filter(|pull| pull.head.ref_field == branch)
175            .collect();
176
177        match matches.as_slice() {
178            [pull] => Ok(pull.number),
179            [] => bail!("No open GitHub PR found for branch `{branch}`; pass --pr <number>"),
180            _ => bail!("Multiple open GitHub PRs found for branch `{branch}`; pass --pr <number>"),
181        }
182    }
183
184    async fn validated_inline_comments(
185        &self,
186        pull_number: u64,
187        review: &str,
188    ) -> Result<Vec<Value>> {
189        let diff = self
190            .crab
191            .pulls(&self.repo.owner, &self.repo.name)
192            .get_diff(pull_number)
193            .await
194            .with_context(|| format!("Failed to fetch PR #{pull_number} diff"))?;
195        let reviewable_lines = parse_reviewable_lines(&diff);
196
197        Ok(extract_inline_comment_candidates(review)
198            .into_iter()
199            .filter(|candidate| {
200                reviewable_lines
201                    .get(&candidate.path)
202                    .is_some_and(|lines| lines.contains(&candidate.line))
203            })
204            .map(|candidate| {
205                serde_json::json!({
206                    "path": candidate.path,
207                    "line": candidate.line,
208                    "side": "RIGHT",
209                    "body": candidate.body,
210                })
211            })
212            .collect())
213    }
214}
215
216pub fn find_pull_request_template(repo_root: &Path) -> Result<Option<PullRequestTemplate>> {
217    for path in singular_template_paths(repo_root) {
218        if path.is_file() {
219            return read_template(repo_root, &path).map(Some);
220        }
221    }
222
223    for dir in template_directories(repo_root) {
224        if let Some(template) = directory_template(repo_root, &dir)? {
225            return Ok(Some(template));
226        }
227    }
228
229    Ok(None)
230}
231
/// Candidate single-file template locations under `repo_root`, in lookup order:
/// `.github/`, the repository root, then `docs/`.
fn singular_template_paths(repo_root: &Path) -> [PathBuf; 3] {
    [
        ".github/pull_request_template.md",
        "pull_request_template.md",
        "docs/pull_request_template.md",
    ]
    .map(|relative| repo_root.join(relative))
}
239
/// Candidate template directories under `repo_root`, in lookup order:
/// `.github/`, the repository root, then `docs/`.
fn template_directories(repo_root: &Path) -> [PathBuf; 3] {
    [
        ".github/PULL_REQUEST_TEMPLATE",
        "PULL_REQUEST_TEMPLATE",
        "docs/PULL_REQUEST_TEMPLATE",
    ]
    .map(|relative| repo_root.join(relative))
}
247
248fn directory_template(repo_root: &Path, dir: &Path) -> Result<Option<PullRequestTemplate>> {
249    if !dir.is_dir() {
250        return Ok(None);
251    }
252
253    let default_path = dir.join("pull_request_template.md");
254    if default_path.is_file() {
255        return read_template(repo_root, &default_path).map(Some);
256    }
257
258    let markdown_templates = markdown_files(dir)?;
259    if markdown_templates.len() == 1 {
260        read_template(repo_root, &markdown_templates[0]).map(Some)
261    } else {
262        Ok(None)
263    }
264}
265
266fn markdown_files(dir: &Path) -> Result<Vec<PathBuf>> {
267    let mut files = Vec::new();
268    for entry in fs::read_dir(dir).with_context(|| format!("Failed to read {}", dir.display()))? {
269        let path = entry?.path();
270        if path.is_file()
271            && path
272                .extension()
273                .and_then(|extension| extension.to_str())
274                .is_some_and(|extension| extension.eq_ignore_ascii_case("md"))
275        {
276            files.push(path);
277        }
278    }
279    files.sort();
280    Ok(files)
281}
282
283fn read_template(repo_root: &Path, path: &Path) -> Result<PullRequestTemplate> {
284    let body = fs::read_to_string(path)
285        .with_context(|| format!("Failed to read PR template {}", path.display()))?;
286    let relative_path = path
287        .strip_prefix(repo_root)
288        .unwrap_or(path)
289        .to_string_lossy()
290        .to_string();
291
292    Ok(PullRequestTemplate {
293        path: relative_path,
294        body,
295    })
296}
297
298impl GitHubRepository {
299    pub fn parse(remote_url: &str) -> Result<Self> {
300        if let Some(path) = remote_url.strip_prefix("git@github.com:") {
301            return Self::parse_path(path);
302        }
303
304        let url = Url::parse(remote_url)
305            .with_context(|| format!("Could not parse GitHub remote URL `{remote_url}`"))?;
306        if url.host_str() != Some("github.com") {
307            bail!("Only github.com remotes are supported for GitHub publishing");
308        }
309        Self::parse_path(url.path().trim_start_matches('/'))
310    }
311
312    fn parse_path(path: &str) -> Result<Self> {
313        let clean_path = path.trim_end_matches(".git").trim_end_matches('/');
314        let mut parts = clean_path.split('/');
315        let owner = parts
316            .next()
317            .filter(|part| !part.is_empty())
318            .ok_or_else(|| anyhow!("GitHub remote URL is missing an owner"))?;
319        let name = parts
320            .next()
321            .filter(|part| !part.is_empty())
322            .ok_or_else(|| anyhow!("GitHub remote URL is missing a repository name"))?;
323
324        if parts.next().is_some() {
325            bail!("GitHub remote URL has an unexpected path shape");
326        }
327
328        Ok(Self {
329            owner: owner.to_string(),
330            name: name.to_string(),
331        })
332    }
333}
334
335fn github_remote_url(repo: &GitRepo) -> Result<String> {
336    if let Some(url) = repo.get_remote_url() {
337        return Ok(url.to_string());
338    }
339
340    let raw_repo = repo.open_repo()?;
341    let remote = raw_repo
342        .find_remote("origin")
343        .or_else(|_| {
344            let remotes = raw_repo.remotes()?;
345            let remote_name = remotes
346                .iter()
347                .flatten()
348                .next()
349                .ok_or(git2::Error::from_str("No git remotes configured"))?;
350            raw_repo.find_remote(remote_name)
351        })
352        .context("Could not find a git remote for GitHub publishing")?;
353
354    remote
355        .url()
356        .map(std::string::ToString::to_string)
357        .ok_or_else(|| anyhow!("Git remote has no URL"))
358}
359
360fn single_pull_number(pulls: &[PullRequest]) -> Option<u64> {
361    match pulls {
362        [pull] => Some(pull.number),
363        _ => None,
364    }
365}
366
/// A finding taken from review text that may become an inline comment.
#[derive(Debug, Clone, PartialEq, Eq)]
struct InlineCommentCandidate {
    /// File path exactly as written in the review text.
    path: String,
    /// Line number parsed from the `path:line` location.
    line: u64,
    /// Text of the finding, starting at the line that carries the location.
    body: String,
}
373
374fn extract_inline_comment_candidates(review: &str) -> Vec<InlineCommentCandidate> {
375    let lines: Vec<&str> = review.lines().collect();
376    let mut candidates = Vec::new();
377    let mut index = 0;
378
379    while index < lines.len() {
380        let line = lines[index];
381        if let Some((path, line_number)) = extract_location(line)
382            && looks_like_finding(line)
383        {
384            let body = finding_body(&lines, index);
385            candidates.push(InlineCommentCandidate {
386                path,
387                line: line_number,
388                body,
389            });
390        }
391        index += 1;
392    }
393
394    candidates
395}
396
397fn extract_location(line: &str) -> Option<(String, u64)> {
398    BACKTICK_LOCATION_RE
399        .captures(line)
400        .or_else(|| PLAIN_LOCATION_RE.captures(line))
401        .and_then(|captures| {
402            let path = captures.get(1)?.as_str().to_string();
403            let line = captures.get(2)?.as_str().parse().ok()?;
404            Some((path, line))
405        })
406}
407
/// Returns `true` when `line` contains one of the severity tags `[CRITICAL]`,
/// `[HIGH]`, `[MEDIUM]` or `[LOW]` (case-sensitive, anywhere in the line).
fn looks_like_finding(line: &str) -> bool {
    const SEVERITY_TAGS: [&str; 4] = ["[CRITICAL]", "[HIGH]", "[MEDIUM]", "[LOW]"];

    SEVERITY_TAGS.iter().any(|tag| line.contains(tag))
}
413
414fn finding_body(lines: &[&str], start: usize) -> String {
415    let mut body = Vec::new();
416    let mut index = start;
417
418    while index < lines.len() {
419        let line = lines[index];
420        if index > start && starts_new_finding_or_section(line) {
421            break;
422        }
423        body.push(line.trim());
424        index += 1;
425    }
426
427    body.join("\n").trim().to_string()
428}
429
430fn starts_new_finding_or_section(line: &str) -> bool {
431    let trimmed = line.trim_start();
432    trimmed.starts_with("# ") || trimmed.starts_with("## ") || looks_like_finding(trimmed)
433}
434
435fn parse_reviewable_lines(diff: &str) -> HashMap<String, HashSet<u64>> {
436    let mut lines_by_path = HashMap::new();
437    let mut current_path: Option<String> = None;
438    let mut new_line: Option<u64> = None;
439
440    for line in diff.lines() {
441        if let Some(path) = line.strip_prefix("+++ b/") {
442            current_path = Some(path.to_string());
443            continue;
444        }
445
446        if let Some(captures) = HUNK_RE.captures(line) {
447            new_line = captures.get(1).and_then(|m| m.as_str().parse().ok());
448            continue;
449        }
450
451        let Some(path) = current_path.as_ref() else {
452            continue;
453        };
454        let Some(line_number) = new_line else {
455            continue;
456        };
457
458        if let Some(b'+' | b' ') = line.as_bytes().first().copied() {
459            lines_by_path
460                .entry(path.clone())
461                .or_insert_with(HashSet::new)
462                .insert(line_number);
463            new_line = Some(line_number + 1);
464        }
465    }
466
467    lines_by_path
468}
469
// Unit tests live in the sibling `tests` module file.
#[cfg(test)]
mod tests;