mod graphql;
use std::{borrow::Cow, collections::HashMap, path::Path};
use futures::{join, try_join};
use itertools::Itertools;
use reqwest::Method;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use crate::{
description::FormatMergeRequest,
error::{ConfigSnafu, Error, GitHubApiSnafu, Result},
forge::{
ApprovalSatisfaction,
ApprovalStatus,
CheckStatus,
DiscussionCount,
Forge,
ForgeCreateMergeRequestOptions,
ForgeMergeRequest,
ForgeMergeRequestState,
ForgeUser,
MergeRequestStatus,
UserId,
},
};
/// Client for the REST and GraphQL APIs of github.com or a GitHub Enterprise instance.
///
/// Pull requests are opened against `target_project_id`; head branches live in
/// `source_project_id` (which differs from the target for fork-based workflows).
/// Both identifiers are used in `owner/repo` form.
pub struct GitHubForge {
    /// API root with any trailing `/` stripped (see `GitHubForge::new`).
    base_url: String,
    /// Access token attached to every request.
    token: String,
    /// `owner/repo` of the repository that holds the head branches.
    source_project_id: String,
    /// `owner/repo` of the repository pull requests are opened against.
    target_project_id: String,
    /// HTTP client carrying the TLS configuration chosen at construction time.
    client: reqwest::Client,
}
/// A GitHub account as returned by the REST API (e.g. `GET /user`, `GET /users/{username}`,
/// and the `user` / `assignees` / `requested_reviewers` fields of a pull request).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitHubUser {
/// Numeric account id.
pub id: u64,
/// Account login (the username).
pub login: String,
}
impl ForgeUser for GitHubUser {
    /// The numeric account id rendered as a decimal string (always present).
    fn id(&self) -> Option<Cow<'_, str>> {
        let rendered: String = self.id.to_string();
        Some(Cow::from(rendered))
    }

    /// The account login, borrowed from `self` (always present).
    fn username(&self) -> Option<Cow<'_, str>> {
        let login: &str = self.login.as_str();
        Some(Cow::from(login))
    }
}
/// One end (`head` or `base`) of a pull request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BranchRef {
/// Branch name; the JSON key is `ref`.
#[serde(rename = "ref")]
pub ref_name: String,
/// Commit SHA the branch points at (used to look up check runs).
pub sha: String,
/// Repository holding the branch. Optional in the payload.
/// NOTE(review): presumably null when the source fork was deleted — confirm.
pub repo: Option<GitHubRepo>,
}
/// Minimal repository payload; only the `full_name` field is kept.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitHubRepo {
pub full_name: String,
}
/// A pull request as returned by the `/repos/{owner}/{repo}/pulls` REST endpoints.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PullRequest {
/// Per-repository PR number; used as the forge-neutral `iid`.
pub number: u64,
/// Global object id (distinct from `number`).
pub id: u64,
pub title: String,
/// PR description; may be absent/null.
pub body: Option<String>,
/// Source side of the PR.
pub head: BranchRef,
/// Target side of the PR.
pub base: BranchRef,
/// Raw state string; compared against `"open"` when mapping to `ForgeMergeRequestState`.
pub state: String,
/// Browser URL of the PR.
pub html_url: String,
/// PR author.
pub user: GitHubUser,
/// Creation time as an ISO 8601 string (parsed lazily in `created_at()`).
pub created_at: String,
pub assignees: Vec<GitHubUser>,
/// Reviewers whose review is still requested.
pub requested_reviewers: Vec<GitHubUser>,
pub draft: bool,
/// Defaults to `false` when the payload omits the field.
#[serde(default)]
pub merged: bool,
}
impl ForgeMergeRequest for PullRequest {
    type User = GitHubUser;
    type Id = u64;

    /// GitHub's per-repository PR number (not the global `id`).
    fn iid(&self) -> Self::Id {
        self.number
    }

    fn title(&self) -> &str {
        self.title.as_str()
    }

    /// The PR body, or the empty string when GitHub reports none.
    fn description(&self) -> &str {
        match &self.body {
            Some(body) => body.as_str(),
            None => "",
        }
    }

    fn source_branch(&self) -> &str {
        self.head.ref_name.as_str()
    }

    fn target_branch(&self) -> &str {
        self.base.ref_name.as_str()
    }

    /// Maps GitHub's `merged` flag and `state` string onto the forge-neutral state.
    ///
    /// `merged` takes precedence; otherwise anything other than `"open"` counts as closed.
    fn state(&self) -> ForgeMergeRequestState {
        match (self.merged, self.state.as_str()) {
            (true, _) => ForgeMergeRequestState::Merged,
            (false, "open") => ForgeMergeRequestState::Open,
            (false, _) => ForgeMergeRequestState::Closed,
        }
    }

    fn url(&self) -> Cow<'_, str> {
        Cow::from(self.html_url.as_str())
    }

    /// GitHub edits a PR on the same page it is viewed on, so this is `html_url` too.
    fn edit_url(&self) -> Cow<'_, str> {
        Cow::from(self.html_url.as_str())
    }

    fn author_username(&self) -> &str {
        self.user.login.as_str()
    }

    /// Parses the stored ISO 8601 `created_at` string.
    ///
    /// # Panics
    /// Panics if `created_at` is not a valid ISO 8601 timestamp.
    fn created_at(&self) -> jiff::Timestamp {
        let parsed = self.created_at.parse::<jiff::Timestamp>();
        parsed.expect("Failed to parse creation date as ISO 8601")
    }

    fn assignees(&self) -> Vec<Self::User> {
        self.assignees.to_vec()
    }

    /// Users whose review is currently requested.
    fn reviewers(&self) -> Vec<Self::User> {
        self.requested_reviewers.to_vec()
    }

    fn is_draft(&self) -> bool {
        self.draft
    }

    fn clone_boxed(
        &self,
    ) -> Box<dyn ForgeMergeRequest<User = Self::User, Id = Self::Id> + Send + Sync>
    where
        Self: Sync + Send,
    {
        let copy: PullRequest = self.clone();
        Box::new(copy)
    }
}
/// State of a pull request review, as the upper-case strings GitHub returns from
/// `GET /repos/{owner}/{repo}/pulls/{number}/reviews`.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum ReviewState {
#[serde(rename = "APPROVED")]
Approved,
#[serde(rename = "CHANGES_REQUESTED")]
ChangesRequested,
#[serde(rename = "COMMENTED")]
Commented,
#[serde(rename = "DISMISSED")]
Dismissed,
#[serde(rename = "PENDING")]
Pending,
}
/// A single review on a pull request.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Review {
pub id: u64,
/// Review author; reviews are grouped per `user.id` when counting approvals.
pub user: GitHubUser,
pub body: Option<String>,
pub state: ReviewState,
pub html_url: String,
/// ISO 8601 submission time; reviews without one are skipped when counting approvals.
pub submitted_at: Option<String>,
}
/// The `type` of a rule returned by `GET /repos/{owner}/{repo}/rules/branches/{branch}`.
///
/// GitHub spells rule types in snake_case (`"pull_request"`, `"required_status_checks"`,
/// ...). Only the pull-request rule is interesting here; every other type deserializes to
/// [`BranchRuleType::Unknown`] so new rule kinds never break parsing.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
enum BranchRuleType {
    /// `"pull_request"`: its parameters carry `required_approving_review_count`.
    PullRequest,
    /// Any rule type this client does not inspect.
    #[serde(other)]
    Unknown,
}
/// One rule applying to a branch, as returned by the branch-rules REST endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct BranchRule {
/// Rule kind; the JSON key is `type`.
#[serde(rename = "type")]
pub rule_type: BranchRuleType,
/// Rule-specific parameters; `None` when the payload has none.
#[serde(default)]
pub parameters: Option<BranchRuleParameters>,
}
/// The subset of rule parameters this client reads.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct BranchRuleParameters {
/// Approvals required before merging; only read for pull-request rules.
#[serde(default)]
pub required_approving_review_count: Option<u32>,
}
/// Response of `GET /repos/{owner}/{repo}/commits/{ref}/check-runs`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CheckRunsResponse {
/// Total number of check runs for the ref (may exceed `check_runs.len()` when paginated).
pub total_count: u32,
/// The check runs included in this page.
pub check_runs: Vec<CheckRun>,
}
/// Lifecycle status of a check run, as snake_case strings.
/// NOTE(review): there is no catch-all variant, so an unrecognised status string fails
/// deserialization of the whole response.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
enum CheckRunStatus {
Queued,
InProgress,
Completed,
Waiting,
Requested,
Pending,
}
/// Final outcome of a completed check run.
///
/// GitHub spells the cancelled outcome `cancelled` (double "l") and also reports `stale`.
/// Both are accepted, `canceled` is kept as an alias, and any value added in the future
/// maps to [`CheckRunConclusion::Unknown`] instead of failing the whole check-run listing.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
enum CheckRunConclusion {
    Success,
    Failure,
    Neutral,
    #[serde(rename = "cancelled", alias = "canceled")]
    Canceled,
    Skipped,
    TimedOut,
    ActionRequired,
    /// Run was marked stale by GitHub.
    Stale,
    /// A conclusion this client does not know about.
    #[serde(other)]
    Unknown,
}
/// A single check run on a commit.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CheckRun {
pub id: u64,
pub name: String,
pub status: CheckRunStatus,
/// Set once the run has completed; `None` while it is still running.
pub conclusion: Option<CheckRunConclusion>,
}
/// Envelope of a GraphQL response: the `data` payload plus optional `errors`.
/// `data` is not optional here; callers that must tolerate a null/absent `data` should
/// instantiate `T` as an `Option`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct GraphQLResponse<T> {
data: T,
errors: Option<Vec<GraphQLError>>,
}
/// One entry of a GraphQL `errors` array (only the message is kept).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct GraphQLError {
message: String,
}
impl GitHubForge {
    /// Builds a client for the GitHub API rooted at `base_url`.
    ///
    /// * `base_url` – API root, e.g. `https://api.github.com` or
    ///   `https://ghe.example.com/api/v3`; a trailing `/` is removed.
    /// * `source_project_id` – `owner/repo` holding the head branches.
    /// * `target_project_id` – `owner/repo` pull requests are opened against.
    /// * `token` – access token sent with every request.
    /// * `ca_bundle` – optional PEM bundle whose certificates are added as trusted roots.
    /// * `accept_non_compliant_certs` – when `true`, TLS certificate validation is disabled.
    ///
    /// # Errors
    /// Returns a configuration error if the CA bundle cannot be read or parsed, or if the
    /// HTTP client cannot be built.
    pub fn new(
        base_url: impl Into<String>,
        source_project_id: impl Into<String>,
        target_project_id: impl Into<String>,
        token: impl Into<String>,
        ca_bundle: Option<impl AsRef<Path>>,
        accept_non_compliant_certs: bool,
    ) -> Result<Self> {
        let mut client_builder = reqwest::Client::builder();
        if accept_non_compliant_certs {
            client_builder = client_builder.tls_danger_accept_invalid_certs(true);
        }
        if let Some(ca_path) = ca_bundle {
            let ca_cert = std::fs::read(ca_path.as_ref()).map_err(|e| {
                ConfigSnafu {
                    message: format!(
                        "Failed to read CA bundle at {}: {}",
                        ca_path.as_ref().to_string_lossy(),
                        e
                    ),
                }
                .build()
            })?;
            let certs = reqwest::Certificate::from_pem_bundle(&ca_cert).map_err(|e| {
                ConfigSnafu {
                    message: format!("Failed to parse CA bundle: {}", e),
                }
                .build()
            })?;
            for cert in certs {
                client_builder = client_builder.add_root_certificate(cert);
            }
        }
        let client = client_builder.build().map_err(|e| {
            ConfigSnafu {
                message: format!("Failed to build HTTP client: {}", e),
            }
            .build()
        })?;
        let base_url = base_url.into().trim_end_matches('/').to_string();
        Ok(Self {
            base_url,
            source_project_id: source_project_id.into(),
            target_project_id: target_project_id.into(),
            token: token.into(),
            client,
        })
    }

    /// Sends a REST request to `base_url + path` and deserializes the JSON response as `T`.
    ///
    /// `payload`, when present, is sent as the JSON request body.
    ///
    /// # Errors
    /// Propagates transport errors. Returns a GitHub API error for non-2xx statuses (the
    /// message names the HTTP method and includes the status and response body) and for
    /// bodies that do not deserialize as `T`.
    async fn request<T: DeserializeOwned>(
        &self,
        method: Method,
        path: impl AsRef<str>,
        payload: Option<impl Serialize>,
    ) -> Result<T> {
        let path = path.as_ref();
        // The request builder consumes `method`; keep its name for error messages so a
        // failed POST/PATCH is not reported as a GET.
        let method_name = method.to_string();
        let mut req = self
            .client
            .request(method, format!("{}{}", self.base_url, path))
            .header("Authorization", format!("token {}", &self.token))
            .header("Accept", "application/vnd.github+json")
            .header("X-GitHub-Api-Version", "2022-11-28")
            .header("User-Agent", "jj-vine");
        if let Some(payload) = payload.as_ref() {
            req = req.json(payload);
        }
        let response = req.send().await?;
        let status = response.status();
        if !status.is_success() {
            let text = response.text().await?;
            // `user_by_username` detects "not found" by the numeric status appearing in this
            // message, so the status must stay in it.
            return Err(GitHubApiSnafu {
                message: format!("{} request failed: {} - {}", method_name, status, text),
            }
            .build());
        }
        let body = response.text().await?;
        serde_json::from_str(&body).map_err(|e| {
            GitHubApiSnafu {
                message: format!(
                    "Failed to parse {} response to {}: {}, response: {}",
                    method_name, path, e, body
                ),
            }
            .build()
        })
    }

    /// Runs a GraphQL query and returns the response's `data`.
    ///
    /// The endpoint is derived from `base_url`: `https://api.github.com...` uses
    /// `https://api.github.com/graphql`, a URL containing `/api/v3` (GitHub Enterprise) has it
    /// replaced by `/api/graphql`, and anything else gets `/graphql` appended.
    ///
    /// # Errors
    /// Fails on transport errors, non-2xx statuses, unparseable bodies, a non-empty `errors`
    /// array, or a response that carries no `data`.
    async fn graphql<T: DeserializeOwned>(
        &self,
        query: &str,
        variables: serde_json::Value,
    ) -> Result<T> {
        let graphql_url = if self.base_url.starts_with("https://api.github.com") {
            "https://api.github.com/graphql".to_string()
        } else if self.base_url.contains("/api/v3") {
            self.base_url.replace("/api/v3", "/api/graphql")
        } else {
            format!("{}/graphql", self.base_url)
        };
        let payload = serde_json::json!({
            "query": query,
            "variables": variables,
        });
        let response = self
            .client
            .post(&graphql_url)
            .header("Authorization", format!("Bearer {}", &self.token))
            .header("Accept", "application/vnd.github+json")
            .header("User-Agent", "jj-vine")
            .json(&payload)
            .send()
            .await?;
        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await?;
            return Err(GitHubApiSnafu {
                message: format!("GraphQL request failed: {} - {}", status, text),
            }
            .build());
        }
        let body = response.text().await?;
        // When a GraphQL request errors, `data` is null or absent. Parse it as optional so the
        // server's `errors` are reported instead of a "missing field `data`" parse failure.
        let data: GraphQLResponse<Option<T>> = serde_json::from_str(&body).map_err(|e| {
            GitHubApiSnafu {
                message: format!(
                    "Failed to parse GraphQL response: {}, response: {}",
                    e, body
                ),
            }
            .build()
        })?;
        if let Some(errors) = data.errors.filter(|errors| !errors.is_empty()) {
            return Err(GitHubApiSnafu {
                message: format!(
                    "GraphQL request failed: {}",
                    errors.iter().map(|error| error.message.as_str()).join(", ")
                ),
            }
            .build());
        }
        data.data.ok_or_else(|| {
            GitHubApiSnafu {
                message: format!("GraphQL response contained no data: {}", body),
            }
            .build()
        })
    }
}
impl Forge for GitHubForge {
type User = GitHubUser;
type MergeRequest = PullRequest;
type UserId = UserId<u64>;
fn project_id(&self) -> &str {
&self.target_project_id
}
fn source_project_id(&self) -> &str {
&self.source_project_id
}
fn target_project_id(&self) -> &str {
&self.target_project_id
}
fn base_url(&self) -> &str {
&self.base_url
}
fn project_url(&self) -> String {
let base_url = if self.base_url.starts_with("https://api.github.com") {
"https://github.com"
} else if self.base_url.contains("/api/v3") {
self.base_url.trim_end_matches("/api/v3")
} else {
&self.base_url
};
format!("{}/{}", base_url, self.target_project_id)
}
async fn current_user(&self) -> Result<Self::User> {
let user: GitHubUser = self.request(Method::GET, "/user", None::<()>).await?;
Ok(user)
}
async fn user_by_username(&self, username: &str) -> Result<Option<Self::User>> {
match self
.request::<GitHubUser>(Method::GET, format!("/users/{}", username), None::<()>)
.await
{
Ok(user) => Ok(Some(user)),
Err(Error::GitHubApi { message, .. }) if message.contains("404") => Ok(None),
Err(e) => Err(e),
}
}
async fn find_merge_request_by_source_branch(
&self,
branch: &str,
) -> Result<Option<Self::MergeRequest>> {
let source_owner = self.source_project_id.split('/').next().unwrap();
let prs: Vec<PullRequest> = self
.request(
Method::GET,
format!(
"/repos/{}/pulls?head={}:{}&state=open",
self.target_project_id,
source_owner,
urlencoding::encode(branch)
),
None::<()>,
)
.await?;
Ok(prs.into_iter().next())
}
async fn create_merge_request(
&self,
ForgeCreateMergeRequestOptions {
assignees,
description,
reviewers,
source_branch,
target_branch,
title,
open_as_draft,
remove_source_branch: _remove_source_branch,
squash: _squash,
}: ForgeCreateMergeRequestOptions<Self::UserId>,
) -> Result<Self::MergeRequest> {
let (source_owner, source_repository) = self.source_project_id.split_once('/').unwrap();
let head = if self.source_project_id != self.target_project_id {
format!("{}:{}", source_owner, source_branch)
} else {
source_branch.clone()
};
let head_repo = if self.source_project_id != self.target_project_id {
source_repository
} else {
let (_target_owner, target_repository) =
self.target_project_id.split_once('/').unwrap();
target_repository
};
let mut payload = serde_json::json!({
"title": title,
"head": head,
"head_repo": head_repo,
"base": target_branch,
"draft": open_as_draft,
});
if let Some(description) = description {
payload["body"] = serde_json::json!(description);
}
let pr: PullRequest = self
.request(
Method::POST,
format!("/repos/{}/pulls", self.target_project_id),
Some(payload),
)
.await?;
if !assignees.is_empty() {
self.add_assignees(
pr.number,
assignees.into_iter().map(|user| user.0).collect(),
)
.await?;
}
if !reviewers.is_empty() {
self.request_reviewers(
pr.number,
reviewers.into_iter().map(|user| user.0).collect(),
)
.await?;
}
Ok(pr)
}
async fn update_merge_request_base(
&self,
pr_number: Self::Id,
new_base: &str,
) -> Result<Self::MergeRequest> {
let pr: PullRequest = self
.request(
Method::PATCH,
format!("/repos/{}/pulls/{}", self.target_project_id, pr_number),
Some(serde_json::json!({
"base": new_base,
})),
)
.await?;
Ok(pr)
}
async fn update_merge_request_description(
&self,
pr_number: Self::Id,
new_description: &str,
) -> Result<Self::MergeRequest> {
let pr: PullRequest = self
.request(
Method::PATCH,
format!("/repos/{}/pulls/{}", self.target_project_id, pr_number),
Some(serde_json::json!({
"body": new_description,
})),
)
.await?;
Ok(pr)
}
async fn get_merge_request(&self, pr_number: Self::Id) -> Result<Self::MergeRequest> {
let pr: PullRequest = self
.request(
Method::GET,
format!("/repos/{}/pulls/{}", self.target_project_id, pr_number),
None::<()>,
)
.await?;
Ok(pr)
}
async fn get_approval_status(&self, pr_number: Self::Id) -> Result<ApprovalStatus> {
let pr = self.get_merge_request(pr_number).await?;
let base_branch = pr.base.ref_name;
let (reviews, required_count): (Result<Vec<Review>, _>, _) = join!(
self.request(
Method::GET,
format!(
"/repos/{}/pulls/{}/reviews",
self.target_project_id, pr_number
),
None::<()>,
),
self.get_required_approvals(&base_branch),
);
let user_reviews: Result<HashMap<u64, ReviewState>, _> = reviews.map(|reviews| {
reviews
.iter()
.filter(|review| review.submitted_at.is_some())
.sorted_by_key(|review| review.submitted_at.as_ref().unwrap())
.map(|review| (review.user.id, review.state.clone()))
.collect()
});
let approved_count = user_reviews.as_ref().map(|reviews| {
reviews
.values()
.filter(|state| matches!(state, ReviewState::Approved))
.count() as u32
});
let blocking_count = user_reviews.as_ref().map(|reviews| {
reviews
.values()
.filter(|state| matches!(state, ReviewState::ChangesRequested))
.count() as u32
});
Ok(ApprovalStatus {
approved_count: approved_count.unwrap_or(0),
required_count: *required_count.as_ref().unwrap_or(&0),
blocking_count: blocking_count.unwrap_or(0),
satisfaction: match (required_count, approved_count) {
(Ok(count), Ok(approved_count)) => {
if approved_count >= count {
ApprovalSatisfaction::Satisfied
} else {
ApprovalSatisfaction::Unsatisfied
}
}
(_, _) => ApprovalSatisfaction::Unknown,
},
})
}
async fn get_check_status(&self, pr_number: Self::Id) -> Result<CheckStatus> {
let head_sha = self.get_merge_request(pr_number).await?.head.sha;
let response: CheckRunsResponse = self
.request(
Method::GET,
format!(
"/repos/{}/commits/{}/check-runs",
self.target_project_id,
urlencoding::encode(&head_sha)
),
None::<()>,
)
.await?;
if response.total_count == 0 {
return Ok(CheckStatus::None);
}
let mut has_pending = false;
let mut has_failed = false;
for check_run in response.check_runs {
match (check_run.status, check_run.conclusion) {
(
CheckRunStatus::Completed,
Some(
CheckRunConclusion::Failure
| CheckRunConclusion::Canceled
| CheckRunConclusion::TimedOut
| CheckRunConclusion::ActionRequired,
),
) => {
has_failed = true;
}
(CheckRunStatus::Completed, _) => {}
(CheckRunStatus::Queued, _)
| (CheckRunStatus::InProgress, _)
| (CheckRunStatus::Waiting, _)
| (CheckRunStatus::Pending, _)
| (CheckRunStatus::Requested, _) => {
has_pending = true;
}
}
}
if has_failed {
Ok(CheckStatus::Failed)
} else if has_pending {
Ok(CheckStatus::Pending)
} else {
Ok(CheckStatus::Success)
}
}
async fn get_merge_request_status(&self, pr_number: Self::Id) -> Result<MergeRequestStatus> {
let (approval_status, check_status) = try_join!(
self.get_approval_status(pr_number),
self.get_check_status(pr_number),
)?;
Ok(MergeRequestStatus {
iid: pr_number.to_string(),
approval_status,
check_status,
})
}
async fn num_open_discussions(&self, pr_number: Self::Id) -> Result<DiscussionCount> {
let discussions = self.get_discussions(pr_number).await?;
Ok(discussions.iter().fold(
DiscussionCount {
all: 0,
unresolved: 0,
resolved: 0,
},
|mut acc, comment| {
acc.all += 1;
if comment.is_minimized {
acc.resolved += 1;
} else if comment.viewer_can_minimize {
acc.unresolved += 1;
}
acc
},
))
}
async fn sync_dependent_merge_requests(
&self,
_merge_request_iid: Self::Id,
_dependent_merge_request_iids: &[Self::Id],
) -> Result<bool> {
Ok(false)
}
}
impl FormatMergeRequest for GitHubForge {
    type Id = u64;

    /// GitHub references pull requests as `#<number>`.
    fn format_merge_request_id(&self, mr_iid: Self::Id) -> String {
        let mut formatted = String::from("#");
        formatted.push_str(&mr_iid.to_string());
        formatted
    }

    /// GitHub calls merge requests "pull requests", abbreviated "PR".
    fn mr_name(&self) -> &'static str {
        const NAME: &str = "PR";
        NAME
    }
}
impl GitHubForge {
async fn add_assignees(&self, pr_number: u64, assignees: Vec<u64>) -> Result<()> {
self.request::<serde_json::Value>(
Method::POST,
format!(
"/repos/{}/issues/{}/assignees",
self.target_project_id, pr_number
),
Some(serde_json::json!({
"assignees": assignees,
})),
)
.await?;
Ok(())
}
async fn request_reviewers(&self, pr_number: u64, reviewers: Vec<u64>) -> Result<()> {
self.request::<serde_json::Value>(
Method::POST,
format!(
"/repos/{}/pulls/{}/requested_reviewers",
self.target_project_id, pr_number
),
Some(serde_json::json!({
"reviewers": reviewers,
})),
)
.await?;
Ok(())
}
async fn get_required_approvals(&self, branch: &str) -> Result<u32> {
let rules: Vec<BranchRule> = self
.request(
Method::GET,
format!(
"/repos/{}/rules/branches/{}",
self.target_project_id,
urlencoding::encode(branch)
),
None::<()>,
)
.await?;
for rule in rules {
if rule.rule_type == BranchRuleType::PullRequest
&& let Some(params) = rule.parameters
&& let Some(count) = params.required_approving_review_count
{
return Ok(count);
}
}
Ok(0)
}
async fn get_discussions(
&self,
pr_number: u64,
) -> Result<Vec<graphql::GetDiscussionsQueryComment>> {
let (owner, name) = self.target_project_id.split("/").collect_tuple().unwrap();
let response: graphql::GetDiscussionsQueryResponse = self
.graphql(
r#"
query GetDiscussions($owner: String!, $name: String!, $pr_number: Int!) {
repository(owner: $owner, name: $name) {
pullRequest(number: $pr_number) {
reviews(first: 100) {
nodes {
id
comments(first: 100) {
nodes {
author {
login
}
body
createdAt
editor {
login
}
id
lastEditedAt
isMinimized
minimizedReason
publishedAt
updatedAt
url
viewerCanMinimize
}
}
}
}
comments(first: 100) {
nodes {
author {
login
}
body
createdAt
editor {
login
}
id
lastEditedAt
isMinimized
minimizedReason
publishedAt
updatedAt
url
viewerCanMinimize
}
}
}
}
}
"#,
serde_json::json!({
"owner": owner,
"name": name,
"pr_number": pr_number,
}),
)
.await?;
let pull_request = response
.repository
.ok_or(
GitHubApiSnafu {
message: format!("Repository {} not found", self.target_project_id),
}
.build(),
)?
.pull_request
.ok_or(
GitHubApiSnafu {
message: format!("Pull request {} not found", pr_number),
}
.build(),
)?;
let root_comments = &pull_request.comments.nodes;
let reviews = &pull_request.reviews.unwrap_or_default().nodes;
let review_comments = reviews
.iter()
.flat_map(|review| review.comments.nodes.iter())
.collect::<Vec<_>>();
Ok(root_comments
.iter()
.chain(review_comments.into_iter())
.cloned()
.collect())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a same-repository client (`owner/repo` on both sides) with default TLS settings.
    fn forge(base_url: &str, token: &str) -> GitHubForge {
        GitHubForge::new(
            base_url,
            "owner/repo",
            "owner/repo",
            token,
            None::<&str>,
            false,
        )
        .expect("Failed to create client")
    }

    #[test]
    fn test_github_client_new() {
        let client = forge("https://api.github.com", "ghp_token123");
        assert_eq!(client.base_url, "https://api.github.com");
        assert_eq!(client.source_project_id, "owner/repo");
        assert_eq!(client.target_project_id, "owner/repo");
        assert_eq!(client.token, "ghp_token123");
    }

    #[test]
    fn test_project_url() {
        let client = forge("https://api.github.com", "token");
        assert_eq!(client.project_url(), "https://github.com/owner/repo");
    }

    #[test]
    fn test_github_enterprise_url() {
        let client = forge("https://github.example.com/api/v3", "token");
        assert_eq!(
            client.project_url(),
            "https://github.example.com/owner/repo"
        );
    }
}