//! `libverify_github/graphql.rs`: GraphQL queries and response conversion for
//! GitHub pull-request data.

1use std::collections::HashMap;
2
3use anyhow::{Context, Result, bail};
4use serde::Deserialize;
5
6use crate::client::GitHubClient;
7use crate::types::*;
8
9// -- Batch sizes --
10
/// Maximum number of pull requests requested by a single batched PR query.
const PR_BATCH_SIZE: usize = 20;
/// Maximum number of commit SHAs resolved by a single commit -> PR query.
const COMMIT_PR_BATCH_SIZE: usize = 100;
13
14// -- GraphQL response wrapper --
15
16#[derive(Deserialize)]
17struct GqlResponse {
18    data: Option<serde_json::Value>,
19    errors: Option<Vec<GqlError>>,
20}
21
22#[derive(Deserialize)]
23struct GqlError {
24    message: String,
25}
26
27// -- PR data types (GraphQL-specific) --
28
29#[derive(Deserialize)]
30struct GqlPullRequest {
31    number: u32,
32    title: String,
33    body: Option<String>,
34    author: Option<GqlActor>,
35    #[serde(rename = "headRefOid")]
36    head_ref_oid: String,
37    #[serde(rename = "baseRefName")]
38    base_ref_name: String,
39    files: Option<GqlConnection<GqlChangedFile>>,
40    reviews: Option<GqlConnection<GqlReview>>,
41    commits: Option<GqlConnection<GqlPrCommitNode>>,
42    #[serde(rename = "statusChecks")]
43    status_checks: Option<GqlConnection<GqlStatusCheckNode>>,
44}
45
46#[derive(Deserialize)]
47struct GqlActor {
48    login: String,
49}
50
51#[derive(Deserialize)]
52struct GqlConnection<T> {
53    nodes: Vec<T>,
54}
55
56#[derive(Deserialize)]
57struct GqlChangedFile {
58    path: String,
59    additions: u32,
60    deletions: u32,
61    #[serde(rename = "changeType")]
62    change_type: String,
63}
64
65#[derive(Deserialize)]
66struct GqlReview {
67    author: Option<GqlActor>,
68    state: String,
69    #[serde(rename = "submittedAt")]
70    submitted_at: Option<String>,
71    body: Option<String>,
72}
73
74#[derive(Deserialize)]
75struct GqlPrCommitNode {
76    commit: GqlPrCommit,
77}
78
79#[derive(Deserialize)]
80struct GqlPrCommit {
81    oid: String,
82    author: Option<GqlCommitAuthor>,
83    committer: Option<GqlCommitDate>,
84    signature: Option<GqlSignature>,
85}
86
87#[derive(Deserialize)]
88struct GqlCommitAuthor {
89    user: Option<GqlActor>,
90}
91
92#[derive(Deserialize)]
93struct GqlCommitDate {
94    date: Option<String>,
95}
96
97#[derive(Deserialize)]
98struct GqlSignature {
99    #[serde(rename = "isValid")]
100    is_valid: bool,
101    state: String,
102}
103
104#[derive(Deserialize)]
105struct GqlStatusCheckNode {
106    commit: GqlStatusCheckCommit,
107}
108
109#[derive(Deserialize)]
110struct GqlStatusCheckCommit {
111    #[serde(rename = "statusCheckRollup")]
112    status_check_rollup: Option<GqlStatusCheckRollup>,
113}
114
115#[derive(Deserialize)]
116struct GqlStatusCheckRollup {
117    contexts: GqlConnection<serde_json::Value>,
118}
119
120// -- Commit -> PR resolution types --
121
122#[derive(Deserialize)]
123struct GqlCommitWithPrs {
124    #[serde(rename = "associatedPullRequests")]
125    associated_pull_requests: GqlConnection<GqlAssociatedPr>,
126}
127
128#[derive(Deserialize)]
129struct GqlAssociatedPr {
130    number: u32,
131    #[serde(rename = "mergedAt")]
132    merged_at: Option<String>,
133    author: Option<GqlActor>,
134}
135
136// -- Public data type --
137
138pub struct PrData {
139    pub metadata: PrMetadata,
140    pub files: Vec<PrFile>,
141    pub reviews: Vec<Review>,
142    pub commits: Vec<PrCommit>,
143    pub check_runs: Vec<CheckRunItem>,
144    pub commit_statuses: Vec<CommitStatusItem>,
145}
146
147// -- Query fragment --
148
/// GraphQL fragment selecting every pull-request field this module consumes.
///
/// Every connection asks for at most 100 items: GitHub's GraphQL API rejects
/// `first`/`last` values outside 1–100, so a larger page size makes the whole
/// query fail.
///
/// NOTE(review): no cursor pagination is performed, so files, reviews, commits
/// and status contexts beyond the first 100 are not returned.
fn pr_fields_fragment() -> &'static str {
    r#"fragment PrFields on PullRequest {
  number title body
  author { login }
  headRefOid baseRefName
  files(first: 100) {
    nodes { path additions deletions changeType }
  }
  reviews(first: 100) {
    nodes { author { login } state submittedAt body }
  }
  commits(first: 100) {
    nodes {
      commit {
        oid
        author { user { login } }
        committer { date }
        signature { isValid state }
      }
    }
  }
  statusChecks: commits(last: 1) {
    nodes {
      commit {
        statusCheckRollup {
          contexts(first: 100) {
            nodes {
              __typename
              ... on CheckRun {
                name status conclusion
                checkSuite { app { slug } }
              }
              ... on StatusContext {
                context state
              }
            }
          }
        }
      }
    }
  }
}"#
}

// -- Query builders --

/// Render `value` as a quoted GraphQL string literal.
///
/// Quotes, backslashes and control characters are escaped so caller-supplied
/// text cannot terminate the literal and alter the surrounding query. Values
/// without such characters are emitted unchanged between double quotes.
fn gql_string_literal(value: &str) -> String {
    let mut out = String::with_capacity(value.len() + 2);
    out.push('"');
    for ch in value.chars() {
        match ch {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            c if c.is_control() => out.push_str(&format!("\\u{:04x}", u32::from(c))),
            c => out.push(c),
        }
    }
    out.push('"');
    out
}

/// Build the query fetching one pull request (`number`) of `owner/repo`.
fn single_pr_query(owner: &str, repo: &str, number: u32) -> String {
    let owner = gql_string_literal(owner);
    let repo = gql_string_literal(repo);
    let fragment = pr_fields_fragment();
    format!(
        r#"query {{
  repository(owner: {owner}, name: {repo}) {{
    pullRequest(number: {number}) {{ ...PrFields }}
  }}
}}
{fragment}"#
    )
}

/// Build one query fetching several pull requests of `owner/repo`.
///
/// The PR at position `i` of `numbers` is aliased `pr{i}` in the response.
fn batch_pr_query(owner: &str, repo: &str, numbers: &[u32]) -> String {
    let owner = gql_string_literal(owner);
    let repo = gql_string_literal(repo);
    let aliases_str = numbers
        .iter()
        .enumerate()
        .map(|(i, n)| format!("    pr{i}: pullRequest(number: {n}) {{ ...PrFields }}"))
        .collect::<Vec<_>>()
        .join("\n");
    let fragment = pr_fields_fragment();
    format!(
        r#"query {{
  repository(owner: {owner}, name: {repo}) {{
{aliases_str}
  }}
}}
{fragment}"#
    )
}

/// Build one query resolving each SHA in `shas` to its associated pull requests.
///
/// The SHA at position `i` is aliased `c{i}` in the response; up to 10
/// associated pull requests are requested per commit.
fn commit_prs_query(owner: &str, repo: &str, shas: &[&str]) -> String {
    let owner = gql_string_literal(owner);
    let repo = gql_string_literal(repo);
    let aliases = shas
        .iter()
        .enumerate()
        .map(|(i, sha)| {
            let sha = gql_string_literal(sha);
            format!(
                r#"    c{i}: object(expression: {sha}) {{
      ... on Commit {{
        associatedPullRequests(first: 10) {{
          nodes {{ number mergedAt author {{ login }} }}
        }}
      }}
    }}"#
            )
        })
        .collect::<Vec<_>>()
        .join("\n");

    format!(
        r#"query {{
  repository(owner: {owner}, name: {repo}) {{
{aliases}
  }}
}}"#
    )
}
252
253// -- Conversion functions --
254
255fn convert_pr(pr: GqlPullRequest) -> PrData {
256    let metadata = PrMetadata {
257        number: pr.number,
258        title: pr.title,
259        body: pr.body,
260        user: pr.author.map(|a| PrUser { login: a.login }),
261        head: PrHead {
262            sha: pr.head_ref_oid,
263        },
264        base: PrBase {
265            ref_name: pr.base_ref_name,
266        },
267    };
268
269    let files = pr
270        .files
271        .map(|f| f.nodes)
272        .unwrap_or_default()
273        .into_iter()
274        .map(|f| PrFile {
275            filename: f.path,
276            patch: None, // GraphQL does not expose patch content
277            additions: f.additions,
278            deletions: f.deletions,
279            status: convert_change_type(&f.change_type),
280        })
281        .collect();
282
283    let reviews = pr
284        .reviews
285        .map(|r| r.nodes)
286        .unwrap_or_default()
287        .into_iter()
288        .map(|r| Review {
289            user: PrUser {
290                login: r.author.map(|a| a.login).unwrap_or_default(),
291            },
292            state: r.state,
293            submitted_at: r.submitted_at,
294            body: r.body,
295        })
296        .collect();
297
298    let commits = pr
299        .commits
300        .map(|c| c.nodes)
301        .unwrap_or_default()
302        .into_iter()
303        .map(convert_commit_node)
304        .collect();
305
306    let (check_runs, commit_statuses) =
307        extract_status_checks(pr.status_checks.and_then(|sc| sc.nodes.into_iter().next()));
308
309    PrData {
310        metadata,
311        files,
312        reviews,
313        commits,
314        check_runs,
315        commit_statuses,
316    }
317}
318
319fn convert_commit_node(node: GqlPrCommitNode) -> PrCommit {
320    let c = node.commit;
321    PrCommit {
322        sha: c.oid,
323        commit: PrCommitInner {
324            committer: c.committer.map(|ct| PrCommitAuthor { date: ct.date }),
325            verification: Some(match c.signature {
326                Some(sig) => CommitVerification {
327                    verified: sig.is_valid,
328                    reason: sig.state.to_lowercase(),
329                },
330                None => CommitVerification {
331                    verified: false,
332                    reason: "unsigned".to_string(),
333                },
334            }),
335        },
336        author: c
337            .author
338            .and_then(|a| a.user)
339            .map(|u| PrUser { login: u.login }),
340    }
341}
342
/// Translate a GraphQL `changeType` value into the lower-case status string
/// used by `PrFile::status`.
///
/// `DELETED` maps to `"removed"`; any value not listed (including `MODIFIED`)
/// yields `"modified"`.
fn convert_change_type(change_type: &str) -> String {
    const TRANSLATIONS: [(&str, &str); 5] = [
        ("ADDED", "added"),
        ("DELETED", "removed"),
        ("RENAMED", "renamed"),
        ("COPIED", "copied"),
        ("CHANGED", "changed"),
    ];

    TRANSLATIONS
        .iter()
        .find(|(gql_name, _)| *gql_name == change_type)
        .map_or("modified", |&(_, status)| status)
        .to_string()
}
355
356fn extract_status_checks(
357    head_commit: Option<GqlStatusCheckNode>,
358) -> (Vec<CheckRunItem>, Vec<CommitStatusItem>) {
359    let mut check_runs = Vec::new();
360    let mut statuses = Vec::new();
361
362    let Some(node) = head_commit else {
363        return (check_runs, statuses);
364    };
365    let Some(rollup) = node.commit.status_check_rollup else {
366        return (check_runs, statuses);
367    };
368
369    for ctx in rollup.contexts.nodes {
370        let typename = ctx.get("__typename").and_then(|v| v.as_str()).unwrap_or("");
371        match typename {
372            "CheckRun" => {
373                check_runs.push(CheckRunItem {
374                    name: ctx
375                        .get("name")
376                        .and_then(|v| v.as_str())
377                        .unwrap_or("")
378                        .to_string(),
379                    status: ctx
380                        .get("status")
381                        .and_then(|v| v.as_str())
382                        .unwrap_or("")
383                        .to_lowercase(),
384                    conclusion: ctx
385                        .get("conclusion")
386                        .and_then(|v| v.as_str())
387                        .map(|s| s.to_lowercase()),
388                    app: ctx
389                        .get("checkSuite")
390                        .and_then(|cs| cs.get("app"))
391                        .and_then(|app| app.get("slug"))
392                        .and_then(|s| s.as_str())
393                        .map(|slug| CheckRunApp {
394                            slug: slug.to_string(),
395                        }),
396                });
397            }
398            "StatusContext" => {
399                statuses.push(CommitStatusItem {
400                    context: ctx
401                        .get("context")
402                        .and_then(|v| v.as_str())
403                        .unwrap_or("")
404                        .to_string(),
405                    state: ctx
406                        .get("state")
407                        .and_then(|v| v.as_str())
408                        .unwrap_or("")
409                        .to_lowercase(),
410                });
411            }
412            _ => {}
413        }
414    }
415
416    (check_runs, statuses)
417}
418
419fn check_errors(resp: &GqlResponse) -> Result<()> {
420    if let Some(errors) = &resp.errors
421        && resp.data.is_none()
422    {
423        let msgs: Vec<&str> = errors.iter().map(|e| e.message.as_str()).collect();
424        bail!("GraphQL errors: {}", msgs.join(", "));
425    }
426    Ok(())
427}
428
429// -- Public API --
430
431/// Fetch all data for a single PR in one GraphQL call.
432pub fn fetch_pr(client: &GitHubClient, owner: &str, repo: &str, number: u32) -> Result<PrData> {
433    let query = single_pr_query(owner, repo, number);
434    let body = client.post_graphql(&query, None)?;
435
436    let resp: GqlResponse =
437        serde_json::from_str(&body).context("failed to parse GraphQL response")?;
438    check_errors(&resp)?;
439
440    let data = resp.data.context("no data in GraphQL response")?;
441    let pr_value = data
442        .get("repository")
443        .and_then(|r| r.get("pullRequest"))
444        .context("pullRequest not found in response")?;
445
446    if pr_value.is_null() {
447        bail!("PR #{number} not found");
448    }
449
450    let pr: GqlPullRequest =
451        serde_json::from_value(pr_value.clone()).context("failed to deserialize pull request")?;
452
453    Ok(convert_pr(pr))
454}
455
456/// Fetch data for multiple PRs in batched GraphQL calls (up to 20 PRs per query).
457pub fn fetch_prs(
458    client: &GitHubClient,
459    owner: &str,
460    repo: &str,
461    numbers: &[u32],
462) -> Vec<(u32, Result<PrData>)> {
463    let mut results = Vec::new();
464
465    for chunk in numbers.chunks(PR_BATCH_SIZE) {
466        let query = batch_pr_query(owner, repo, chunk);
467        match client.post_graphql(&query, None) {
468            Err(e) => {
469                let msg = format!("{e:#}");
470                for &n in chunk {
471                    results.push((n, Err(anyhow::anyhow!("GraphQL batch failed: {msg}"))));
472                }
473                continue;
474            }
475            Ok(body) => {
476                let resp: GqlResponse = match serde_json::from_str(&body) {
477                    Ok(r) => r,
478                    Err(e) => {
479                        let msg = format!("{e:#}");
480                        for &n in chunk {
481                            results
482                                .push((n, Err(anyhow::anyhow!("failed to parse response: {msg}"))));
483                        }
484                        continue;
485                    }
486                };
487
488                if resp.data.is_none() {
489                    let msg = resp
490                        .errors
491                        .as_ref()
492                        .map(|errs| {
493                            errs.iter()
494                                .map(|e| e.message.as_str())
495                                .collect::<Vec<_>>()
496                                .join(", ")
497                        })
498                        .unwrap_or_else(|| "unknown error".to_string());
499                    for &n in chunk {
500                        results.push((n, Err(anyhow::anyhow!("GraphQL error: {msg}"))));
501                    }
502                    continue;
503                }
504
505                let data = resp.data.unwrap();
506                let repo_data = data.get("repository");
507
508                for (i, &number) in chunk.iter().enumerate() {
509                    let key = format!("pr{i}");
510                    let pr_result = repo_data
511                        .and_then(|r| r.get(&key))
512                        .map(|v| {
513                            if v.is_null() {
514                                Err(anyhow::anyhow!("PR #{number} not found"))
515                            } else {
516                                serde_json::from_value::<GqlPullRequest>(v.clone())
517                                    .map(convert_pr)
518                                    .context("failed to parse PR data")
519                            }
520                        })
521                        .unwrap_or_else(|| {
522                            Err(anyhow::anyhow!("PR #{number} missing from response"))
523                        });
524
525                    results.push((number, pr_result));
526                }
527            }
528        }
529    }
530
531    results
532}
533
534/// Resolve commits to their associated PRs in batched GraphQL calls.
535pub fn resolve_commit_prs(
536    client: &GitHubClient,
537    owner: &str,
538    repo: &str,
539    shas: &[&str],
540) -> Result<HashMap<String, Vec<PullRequestSummary>>> {
541    let mut result = HashMap::new();
542
543    for chunk in shas.chunks(COMMIT_PR_BATCH_SIZE) {
544        let query = commit_prs_query(owner, repo, chunk);
545        let body = client.post_graphql(&query, None)?;
546        let resp: GqlResponse =
547            serde_json::from_str(&body).context("failed to parse GraphQL response")?;
548        check_errors(&resp)?;
549
550        let data = resp.data.context("no data in GraphQL response")?;
551        let repo_data = data.get("repository").context("repository not found")?;
552
553        for (i, sha) in chunk.iter().enumerate() {
554            let key = format!("c{i}");
555            if let Some(obj) = repo_data.get(&key)
556                && !obj.is_null()
557                && let Ok(commit) = serde_json::from_value::<GqlCommitWithPrs>(obj.clone())
558            {
559                let prs = commit
560                    .associated_pull_requests
561                    .nodes
562                    .into_iter()
563                    .map(|pr| PullRequestSummary {
564                        number: pr.number,
565                        merged_at: pr.merged_at,
566                        user: PrUser {
567                            login: pr.author.map(|a| a.login).unwrap_or_default(),
568                        },
569                    })
570                    .collect();
571                result.insert(sha.to_string(), prs);
572            }
573        }
574    }
575
576    Ok(result)
577}