use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
use octocrab::params::repos::Reference;
use octocrab::service::middleware::retry::RetryConfig;
use octocrab::Octocrab;
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use crate::config::Config;
/// Connect timeout handed to the octocrab builder in `GitHubClient::new`.
const GITHUB_API_CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// Read timeout handed to the octocrab builder in `GitHubClient::new`.
const GITHUB_API_READ_TIMEOUT: Duration = Duration::from_secs(30);
/// Write timeout handed to the octocrab builder in `GitHubClient::new`.
const GITHUB_API_WRITE_TIMEOUT: Duration = Duration::from_secs(30);
/// Value passed to `RetryConfig::Simple` in `GitHubClient::new`
/// (presumably the maximum number of retries per request; confirm against octocrab).
const GITHUB_API_RETRY_COUNT: usize = 1;
/// GitHub API client bound to a single `owner/repo` pair.
///
/// `Clone` is derived: every field is itself `Clone`, and because the tracker
/// lives behind an `Arc`, all clones report into the same request counters.
#[derive(Clone)]
pub struct GitHubClient {
    /// Underlying octocrab client used for every request.
    pub octocrab: Octocrab,
    /// Owner segment used when building repository URLs.
    pub owner: String,
    /// Repository segment used when building repository URLs.
    pub repo: String,
    /// Shared request counters exposed through `api_call_stats`.
    api_call_tracker: Arc<ApiCallTracker>,
}
/// Point-in-time copy of the request counters kept by an `ApiCallTracker`.
#[derive(Debug, Clone)]
pub struct ApiCallStats {
    /// Sum of every count recorded so far.
    pub total_requests: usize,
    /// `(operation, count)` pairs, in the map's ascending key order.
    pub by_operation: Vec<(String, usize)>,
}

/// Request counters: an atomic grand total plus a mutex-guarded per-operation map.
#[derive(Default)]
struct ApiCallTracker {
    total_requests: AtomicUsize,
    by_operation: Mutex<BTreeMap<String, usize>>,
}

impl ApiCallTracker {
    /// Adds `count` requests to the grand total and to `operation`'s bucket.
    ///
    /// A `count` of zero is ignored, so no empty bucket is ever created.
    fn record(&self, operation: &'static str, count: usize) {
        if count != 0 {
            self.total_requests.fetch_add(count, Ordering::Relaxed);
            // A poisoned lock still guards a usable map, so recover it instead of panicking.
            let mut counts = match self.by_operation.lock() {
                Ok(guard) => guard,
                Err(poisoned) => poisoned.into_inner(),
            };
            match counts.get_mut(operation) {
                Some(existing) => *existing += count,
                None => {
                    counts.insert(operation.to_string(), count);
                }
            }
        }
    }

    /// Copies the current counters into an owned [`ApiCallStats`].
    fn snapshot(&self) -> ApiCallStats {
        // The map is copied inside this scope so the lock is released before
        // the atomic total is read.
        let by_operation = {
            let counts = match self.by_operation.lock() {
                Ok(guard) => guard,
                Err(poisoned) => poisoned.into_inner(),
            };
            let mut pairs = Vec::with_capacity(counts.len());
            for (operation, count) in counts.iter() {
                pairs.push((operation.clone(), *count));
            }
            pairs
        };
        let total_requests = self.total_requests.load(Ordering::Relaxed);
        ApiCallStats {
            total_requests,
            by_operation,
        }
    }
}
/// Deserialized body of the check-runs listing fetched by
/// `GitHubClient::get_check_runs_status`.
#[derive(Debug, Deserialize)]
struct CheckRunsResponse {
/// Count reported by the API; zero is treated as "no checks" by the caller.
total_count: usize,
/// The check runs contained in this response.
check_runs: Vec<CheckRun>,
}
/// One entry of a `CheckRunsResponse`; only the fields used for status
/// aggregation are deserialized.
#[derive(Debug, Deserialize)]
struct CheckRun {
/// Lifecycle state; the aggregation compares it against strings such as
/// `"completed"`, `"queued"` and `"in_progress"`.
status: String,
/// Outcome string such as `"success"` or `"failure"`; optional because it can
/// be `null` in the response (the tests pair `null` with unfinished statuses).
conclusion: Option<String>,
}
/// A pull request event returned by `get_recent_merged_prs` and
/// `get_recent_opened_prs`.
#[derive(Debug, Clone, Serialize)]
pub struct PrActivity {
/// Pull request number.
pub number: u64,
/// Pull request title.
pub title: String,
/// When the event happened: the search result's `closed_at` for merged PRs,
/// its `created_at` for opened PRs.
pub timestamp: DateTime<Utc>,
/// The search result's `html_url`.
pub url: String,
}
/// A single review on a pull request, as produced by `get_reviews_received`.
#[derive(Debug, Clone, Serialize)]
pub struct ReviewActivity {
/// Number of the reviewed pull request.
pub pr_number: u64,
/// Title of the reviewed pull request.
pub pr_title: String,
/// Login of the user who submitted the review.
pub reviewer: String,
/// Review state string exactly as returned by the API.
pub state: String,
/// The review's `submitted_at` time.
pub timestamp: DateTime<Utc>,
/// `get_reviews_received` always sets this to `true`; presumably `false`
/// would mark a review given by the user (no code in this file sets it so).
pub is_received: bool, }
/// Summary of an open pull request, as produced by `get_user_open_prs`.
#[derive(Debug, Clone)]
pub struct OpenPrInfo {
/// Pull request number.
pub number: u64,
/// Name of the head ref (the branch being merged).
pub head_branch: String,
/// Name of the base ref (the branch being merged into).
pub base_branch: String,
/// PR state label; `get_user_open_prs` always writes the literal `"OPEN"`.
pub state: String,
/// Whether the PR is a draft; `false` when the API omits the flag.
pub is_draft: bool,
}
/// Author of a `Review`; only the login is deserialized.
#[derive(Debug, Deserialize)]
struct ReviewUser {
/// Login of the review's author.
login: String,
}
/// One element of the `/pulls/{number}/reviews` response read by
/// `get_reviews_received`.
#[derive(Debug, Deserialize)]
struct Review {
/// Review state string as returned by the API.
state: String,
/// Submission time; reviews without one are skipped by `get_reviews_received`.
submitted_at: Option<DateTime<Utc>>,
/// Author of the review; reviews without one are skipped by `get_reviews_received`.
user: Option<ReviewUser>,
}
/// Deserialized body of a `/search/issues` request; only `items` is read.
#[derive(Debug, Deserialize)]
struct SearchIssuesResponse {
/// The matching issues / pull requests contained in this response.
items: Vec<SearchIssue>,
}
/// One search hit from `/search/issues`, restricted to the fields this file uses.
#[derive(Debug, Deserialize)]
struct SearchIssue {
/// Issue / pull request number.
number: u64,
/// Issue / pull request title.
title: String,
/// Browser URL of the item.
html_url: String,
/// Creation time; used as the timestamp for "opened" activity.
created_at: DateTime<Utc>,
/// Close time, if any; used as the timestamp for "merged" activity.
closed_at: Option<DateTime<Utc>>,
}
impl GitHubClient {
/// Creates a client for `owner/repo`, authenticated with the token from
/// `Config::github_token`.
///
/// The octocrab client is configured with the module's retry and timeout
/// constants. When `api_base_url` is `Some`, it replaces octocrab's default
/// base URI (presumably for GitHub Enterprise hosts; confirm with callers).
///
/// # Errors
///
/// Fails when no token is configured, when the base URL is rejected, or when
/// the octocrab client cannot be built.
pub fn new(owner: &str, repo: &str, api_base_url: Option<String>) -> Result<Self> {
    let token = Config::github_token().context(
        "GitHub auth not configured. Use one of: `stax auth`, `stax auth --from-gh`, \
         `gh auth login`, or set `STAX_GITHUB_TOKEN`.",
    )?;

    let default_host_builder = Octocrab::builder()
        .personal_token(token.to_string())
        .add_retry_config(RetryConfig::Simple(GITHUB_API_RETRY_COUNT))
        .set_connect_timeout(Some(GITHUB_API_CONNECT_TIMEOUT))
        .set_read_timeout(Some(GITHUB_API_READ_TIMEOUT))
        .set_write_timeout(Some(GITHUB_API_WRITE_TIMEOUT));

    // Only override the base URI when the caller supplied one.
    let builder = match api_base_url {
        Some(api_base) => default_host_builder
            .base_uri(api_base)
            .context("Failed to set GitHub API base URL")?,
        None => default_host_builder,
    };

    let octocrab = builder.build().context("Failed to create GitHub client")?;

    Ok(Self {
        octocrab,
        owner: owner.to_owned(),
        repo: repo.to_owned(),
        api_call_tracker: Arc::default(),
    })
}
/// Test-only constructor that wraps an already-built octocrab client and
/// starts with fresh (zeroed) request counters.
#[cfg(test)]
pub fn with_octocrab(octocrab: Octocrab, owner: &str, repo: &str) -> Self {
    let api_call_tracker = Arc::default();
    let (owner, repo) = (owner.to_owned(), repo.to_owned());
    Self {
        octocrab,
        owner,
        repo,
        api_call_tracker,
    }
}
pub fn api_call_stats(&self) -> ApiCallStats {
self.api_call_tracker.snapshot()
}
pub(crate) fn record_api_call(&self, operation: &'static str) {
self.api_call_tracker.record(operation, 1);
}
/// Returns the overall CI state for `commit_sha`, or `None` when neither
/// check runs nor a commit status could be obtained.
///
/// Check runs take precedence; the combined commit status is only used as a
/// fallback, rendered as the lowercased `Debug` name of its state. Errors from
/// either request are discarded and treated as "no data", so this method
/// currently never returns `Err`.
pub async fn combined_status_state(&self, commit_sha: &str) -> Result<Option<String>> {
    // `Reference::Commit` puts the SHA into the URL verbatim. `Reference::Branch`
    // would render it as `heads/<sha>`, which does not name the commit, and the
    // resulting error would be hidden by the `.ok()` below.
    let commit_status = self
        .octocrab
        .repos(&self.owner, &self.repo)
        .combined_status_for_ref(&Reference::Commit(commit_sha.to_string()))
        .await
        .ok();
    let check_runs_status = self.get_check_runs_status(commit_sha).await.ok().flatten();

    let state = check_runs_status
        .or_else(|| commit_status.map(|status| format!("{:?}", status.state).to_lowercase()));
    Ok(state)
}
/// Aggregates the check runs of `commit_sha` into one of `"failure"`,
/// `"pending"` or `"success"`.
///
/// Returns `Ok(None)` when the API reports zero check runs. Otherwise:
/// * any completed run concluding `failure`, `timed_out`, `cancelled` or
///   `action_required` yields `"failure"`;
/// * else any run that is not completed, or completed with a conclusion other
///   than `success` / `skipped` / `neutral`, yields `"pending"`;
/// * else `"success"`.
///
/// # Errors
///
/// Propagates request and deserialization errors from octocrab.
async fn get_check_runs_status(&self, commit_sha: &str) -> Result<Option<String>> {
    // Ask for the largest page GitHub serves (100). With the default of 30 a
    // failing or pending run beyond the first page would be missed. Commits
    // with more than 100 runs are still truncated.
    let url = format!(
        "/repos/{}/{}/commits/{}/check-runs?per_page=100",
        self.owner, self.repo, commit_sha
    );
    let response: CheckRunsResponse = self.octocrab.get(&url, None::<&()>).await?;
    if response.total_count == 0 {
        return Ok(None);
    }

    let mut any_failed = false;
    let mut any_unsettled = false;
    for run in &response.check_runs {
        match (run.status.as_str(), run.conclusion.as_deref()) {
            ("completed", Some("success" | "skipped" | "neutral")) => {}
            ("completed", Some("failure" | "timed_out" | "cancelled" | "action_required")) => {
                any_failed = true;
            }
            // Unfinished runs, unknown statuses and unknown conclusions all
            // mean the result cannot be called a success yet.
            _ => any_unsettled = true,
        }
    }

    let state = if any_failed {
        "failure"
    } else if any_unsettled {
        "pending"
    } else {
        "success"
    };
    Ok(Some(state.to_string()))
}
/// Returns the login of the user the client is authenticated as.
///
/// # Errors
///
/// Propagates the octocrab error when the user lookup fails.
pub async fn get_current_user(&self) -> Result<String> {
    let login = self.octocrab.current().user().await?.login;
    Ok(login)
}
/// Lists merged pull requests authored by `username` that were closed within
/// the last `hours` hours.
///
/// Only the first 30 search results (most recently updated first) are
/// inspected; results without a `closed_at` are skipped.
///
/// # Errors
///
/// Propagates request and deserialization errors from octocrab.
pub async fn get_recent_merged_prs(
    &self,
    hours: i64,
    username: &str,
) -> Result<Vec<PrActivity>> {
    let cutoff = Utc::now() - chrono::Duration::hours(hours);
    let url = format!(
        "/search/issues?q=repo:{}/{}+author:{}+is:pr+is:merged&sort=updated&order=desc&per_page=30",
        self.owner, self.repo, username
    );
    let found: SearchIssuesResponse = self.octocrab.get(&url, None::<&()>).await?;

    let mut merged = Vec::new();
    for issue in found.items {
        match issue.closed_at {
            Some(closed_at) if closed_at >= cutoff => merged.push(PrActivity {
                number: issue.number,
                title: issue.title,
                timestamp: closed_at,
                url: issue.html_url,
            }),
            // Never closed, or closed before the cutoff.
            _ => {}
        }
    }
    Ok(merged)
}
/// Lists pull requests authored by `username` that were created within the
/// last `hours` hours.
///
/// Only the first 30 search results (newest first) are inspected.
///
/// # Errors
///
/// Propagates request and deserialization errors from octocrab.
pub async fn get_recent_opened_prs(
    &self,
    hours: i64,
    username: &str,
) -> Result<Vec<PrActivity>> {
    let cutoff = Utc::now() - chrono::Duration::hours(hours);
    let url = format!(
        "/search/issues?q=repo:{}/{}+author:{}+is:pr&sort=created&order=desc&per_page=30",
        self.owner, self.repo, username
    );
    let found: SearchIssuesResponse = self.octocrab.get(&url, None::<&()>).await?;

    let mut opened = Vec::new();
    for issue in found.items {
        if issue.created_at < cutoff {
            continue;
        }
        opened.push(PrActivity {
            number: issue.number,
            title: issue.title,
            timestamp: issue.created_at,
            url: issue.html_url,
        });
    }
    Ok(opened)
}
/// Lists reviews submitted by other users within the last `hours` hours on
/// open pull requests authored by `username`.
///
/// At most 20 open pull requests are inspected. Fetching the reviews of each
/// pull request is best-effort: a failed request counts as "no reviews" for
/// that pull request rather than failing the whole call.
///
/// # Errors
///
/// Propagates request and deserialization errors from the initial search.
pub async fn get_reviews_received(
    &self,
    hours: i64,
    username: &str,
) -> Result<Vec<ReviewActivity>> {
    let since = Utc::now() - chrono::Duration::hours(hours);
    let url = format!(
        "/search/issues?q=repo:{}/{}+author:{}+is:pr+is:open&per_page=20",
        self.owner, self.repo, username
    );
    let response: SearchIssuesResponse = self.octocrab.get(&url, None::<&()>).await?;

    let mut reviews = Vec::new();
    for issue in response.items {
        // Reviews are listed oldest-first, so with the default page size of 30
        // the most recent ones (the ones wanted here) would be cut off on
        // busy pull requests. Request the maximum page size instead.
        let reviews_url = format!(
            "/repos/{}/{}/pulls/{}/reviews?per_page=100",
            self.owner, self.repo, issue.number
        );
        let pr_reviews: Vec<Review> = self
            .octocrab
            .get(&reviews_url, None::<&()>)
            .await
            .unwrap_or_default();

        for review in pr_reviews {
            // Reviews without a submission time or an author are skipped.
            if let (Some(submitted), Some(reviewer)) = (review.submitted_at, review.user) {
                // GitHub logins are case-insensitive, so compare accordingly
                // when filtering out the author's own reviews.
                let is_own_review = reviewer.login.eq_ignore_ascii_case(username);
                if submitted >= since && !is_own_review {
                    reviews.push(ReviewActivity {
                        pr_number: issue.number,
                        pr_title: issue.title.clone(),
                        reviewer: reviewer.login,
                        state: review.state,
                        timestamp: submitted,
                        is_received: true,
                    });
                }
            }
        }
    }
    Ok(reviews)
}
/// Placeholder: ignores both arguments and always returns an empty list.
pub async fn get_reviews_given(
    &self,
    _hours: i64,
    _username: &str,
) -> Result<Vec<ReviewActivity>> {
    let none_yet: Vec<ReviewActivity> = Vec::new();
    Ok(none_yet)
}
/// Lists the open pull requests authored by `username`, with their head and
/// base branch names.
///
/// At most 100 search results are inspected. Each hit is then fetched
/// individually; a pull request whose fetch fails is left out of the result
/// instead of failing the call.
///
/// # Errors
///
/// Fails with "Failed to search PRs" context when the initial search fails.
pub async fn get_user_open_prs(&self, username: &str) -> Result<Vec<OpenPrInfo>> {
    let url = format!(
        "/search/issues?q=repo:{}/{}+author:{}+is:pr+is:open&per_page=100",
        self.owner, self.repo, username
    );
    let search: SearchIssuesResponse = self
        .octocrab
        .get(&url, None::<&()>)
        .await
        .context("Failed to search PRs")?;

    let mut open_prs = Vec::new();
    for issue in search.items {
        let fetched = self
            .octocrab
            .pulls(&self.owner, &self.repo)
            .get(issue.number)
            .await;
        let pr = match fetched {
            Ok(pr) => pr,
            // Best-effort: skip pull requests that cannot be fetched.
            Err(_) => continue,
        };
        let is_draft = pr.draft.unwrap_or(false);
        open_prs.push(OpenPrInfo {
            number: pr.number,
            head_branch: pr.head.ref_field.clone(),
            base_branch: pr.base.ref_field.clone(),
            state: "OPEN".to_string(),
            is_draft,
        });
    }
    Ok(open_prs)
}
}
#[cfg(test)]
mod tests {
use super::*;
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
async fn create_test_client(server: &MockServer) -> GitHubClient {
let octocrab = Octocrab::builder()
.base_uri(server.uri())
.unwrap()
.personal_token("test-token".to_string())
.build()
.unwrap();
GitHubClient::with_octocrab(octocrab, "test-owner", "test-repo")
}
#[tokio::test]
async fn test_check_runs_all_success() {
let mock_server = MockServer::start().await;
Mock::given(method("GET"))
.and(path(
"/repos/test-owner/test-repo/commits/abc123/check-runs",
))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"total_count": 2,
"check_runs": [
{"status": "completed", "conclusion": "success"},
{"status": "completed", "conclusion": "success"}
]
})))
.mount(&mock_server)
.await;
let client = create_test_client(&mock_server).await;
let status = client.get_check_runs_status("abc123").await.unwrap();
assert_eq!(status, Some("success".to_string()));
}
#[tokio::test]
async fn test_check_runs_with_failure() {
let mock_server = MockServer::start().await;
Mock::given(method("GET"))
.and(path(
"/repos/test-owner/test-repo/commits/abc123/check-runs",
))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"total_count": 3,
"check_runs": [
{"status": "completed", "conclusion": "success"},
{"status": "completed", "conclusion": "failure"},
{"status": "completed", "conclusion": "success"}
]
})))
.mount(&mock_server)
.await;
let client = create_test_client(&mock_server).await;
let status = client.get_check_runs_status("abc123").await.unwrap();
assert_eq!(status, Some("failure".to_string()));
}
#[tokio::test]
async fn test_check_runs_with_pending() {
let mock_server = MockServer::start().await;
Mock::given(method("GET"))
.and(path(
"/repos/test-owner/test-repo/commits/abc123/check-runs",
))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"total_count": 2,
"check_runs": [
{"status": "completed", "conclusion": "success"},
{"status": "in_progress", "conclusion": null}
]
})))
.mount(&mock_server)
.await;
let client = create_test_client(&mock_server).await;
let status = client.get_check_runs_status("abc123").await.unwrap();
assert_eq!(status, Some("pending".to_string()));
}
#[tokio::test]
async fn test_check_runs_queued() {
let mock_server = MockServer::start().await;
Mock::given(method("GET"))
.and(path(
"/repos/test-owner/test-repo/commits/abc123/check-runs",
))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"total_count": 1,
"check_runs": [
{"status": "queued", "conclusion": null}
]
})))
.mount(&mock_server)
.await;
let client = create_test_client(&mock_server).await;
let status = client.get_check_runs_status("abc123").await.unwrap();
assert_eq!(status, Some("pending".to_string()));
}
#[tokio::test]
async fn test_check_runs_waiting() {
let mock_server = MockServer::start().await;
Mock::given(method("GET"))
.and(path(
"/repos/test-owner/test-repo/commits/abc123/check-runs",
))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"total_count": 1,
"check_runs": [
{"status": "waiting", "conclusion": null}
]
})))
.mount(&mock_server)
.await;
let client = create_test_client(&mock_server).await;
let status = client.get_check_runs_status("abc123").await.unwrap();
assert_eq!(status, Some("pending".to_string()));
}
#[tokio::test]
async fn test_check_runs_no_checks() {
let mock_server = MockServer::start().await;
Mock::given(method("GET"))
.and(path(
"/repos/test-owner/test-repo/commits/abc123/check-runs",
))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"total_count": 0,
"check_runs": []
})))
.mount(&mock_server)
.await;
let client = create_test_client(&mock_server).await;
let status = client.get_check_runs_status("abc123").await.unwrap();
assert_eq!(status, None);
}
#[tokio::test]
async fn test_check_runs_skipped_and_neutral() {
let mock_server = MockServer::start().await;
Mock::given(method("GET"))
.and(path(
"/repos/test-owner/test-repo/commits/abc123/check-runs",
))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"total_count": 3,
"check_runs": [
{"status": "completed", "conclusion": "success"},
{"status": "completed", "conclusion": "skipped"},
{"status": "completed", "conclusion": "neutral"}
]
})))
.mount(&mock_server)
.await;
let client = create_test_client(&mock_server).await;
let status = client.get_check_runs_status("abc123").await.unwrap();
assert_eq!(status, Some("success".to_string()));
}
#[tokio::test]
async fn test_check_runs_timed_out() {
let mock_server = MockServer::start().await;
Mock::given(method("GET"))
.and(path(
"/repos/test-owner/test-repo/commits/abc123/check-runs",
))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"total_count": 1,
"check_runs": [
{"status": "completed", "conclusion": "timed_out"}
]
})))
.mount(&mock_server)
.await;
let client = create_test_client(&mock_server).await;
let status = client.get_check_runs_status("abc123").await.unwrap();
assert_eq!(status, Some("failure".to_string()));
}
#[tokio::test]
async fn test_check_runs_cancelled() {
let mock_server = MockServer::start().await;
Mock::given(method("GET"))
.and(path(
"/repos/test-owner/test-repo/commits/abc123/check-runs",
))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"total_count": 1,
"check_runs": [
{"status": "completed", "conclusion": "cancelled"}
]
})))
.mount(&mock_server)
.await;
let client = create_test_client(&mock_server).await;
let status = client.get_check_runs_status("abc123").await.unwrap();
assert_eq!(status, Some("failure".to_string()));
}
#[tokio::test]
async fn test_check_runs_action_required() {
let mock_server = MockServer::start().await;
Mock::given(method("GET"))
.and(path(
"/repos/test-owner/test-repo/commits/abc123/check-runs",
))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"total_count": 1,
"check_runs": [
{"status": "completed", "conclusion": "action_required"}
]
})))
.mount(&mock_server)
.await;
let client = create_test_client(&mock_server).await;
let status = client.get_check_runs_status("abc123").await.unwrap();
assert_eq!(status, Some("failure".to_string()));
}
#[tokio::test]
async fn test_check_runs_unknown_conclusion() {
let mock_server = MockServer::start().await;
Mock::given(method("GET"))
.and(path(
"/repos/test-owner/test-repo/commits/abc123/check-runs",
))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"total_count": 1,
"check_runs": [
{"status": "completed", "conclusion": "unknown_state"}
]
})))
.mount(&mock_server)
.await;
let client = create_test_client(&mock_server).await;
let status = client.get_check_runs_status("abc123").await.unwrap();
assert_eq!(status, Some("pending".to_string()));
}
#[tokio::test]
async fn test_check_runs_unknown_status() {
let mock_server = MockServer::start().await;
Mock::given(method("GET"))
.and(path(
"/repos/test-owner/test-repo/commits/abc123/check-runs",
))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"total_count": 1,
"check_runs": [
{"status": "some_unknown_status", "conclusion": null}
]
})))
.mount(&mock_server)
.await;
let client = create_test_client(&mock_server).await;
let status = client.get_check_runs_status("abc123").await.unwrap();
assert_eq!(status, Some("pending".to_string()));
}
#[tokio::test]
async fn test_check_runs_requested_status() {
let mock_server = MockServer::start().await;
Mock::given(method("GET"))
.and(path(
"/repos/test-owner/test-repo/commits/abc123/check-runs",
))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"total_count": 1,
"check_runs": [
{"status": "requested", "conclusion": null}
]
})))
.mount(&mock_server)
.await;
let client = create_test_client(&mock_server).await;
let status = client.get_check_runs_status("abc123").await.unwrap();
assert_eq!(status, Some("pending".to_string()));
}
#[tokio::test]
async fn test_check_runs_pending_status() {
let mock_server = MockServer::start().await;
Mock::given(method("GET"))
.and(path(
"/repos/test-owner/test-repo/commits/abc123/check-runs",
))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"total_count": 1,
"check_runs": [
{"status": "pending", "conclusion": null}
]
})))
.mount(&mock_server)
.await;
let client = create_test_client(&mock_server).await;
let status = client.get_check_runs_status("abc123").await.unwrap();
assert_eq!(status, Some("pending".to_string()));
}
/// `with_octocrab` stores the owner and repo it was given.
#[tokio::test]
async fn test_with_octocrab() {
    let server = MockServer::start().await;
    let builder = Octocrab::builder().base_uri(server.uri()).unwrap();
    let crab = builder
        .personal_token("test-token".to_string())
        .build()
        .unwrap();
    let client = GitHubClient::with_octocrab(crab, "owner", "repo");
    assert_eq!(
        (client.owner.as_str(), client.repo.as_str()),
        ("owner", "repo")
    );
}
/// A check-runs payload deserializes into the expected count and runs.
#[test]
fn test_check_run_response_deserialization() {
    let payload = r#"{
"total_count": 2,
"check_runs": [
{"status": "completed", "conclusion": "success"},
{"status": "in_progress", "conclusion": null}
]
}"#;
    let parsed: CheckRunsResponse = serde_json::from_str(payload).unwrap();
    assert_eq!(parsed.total_count, 2);

    // Comparing the whole list at once also pins its length.
    let observed: Vec<(&str, Option<&str>)> = parsed
        .check_runs
        .iter()
        .map(|run| (run.status.as_str(), run.conclusion.as_deref()))
        .collect();
    assert_eq!(
        observed,
        vec![("completed", Some("success")), ("in_progress", None)]
    );
}
/// A single check run deserializes its status and conclusion.
#[test]
fn test_check_run_deserialization() {
    let run = serde_json::from_str::<CheckRun>(
        r#"{"status": "completed", "conclusion": "failure"}"#,
    )
    .unwrap();
    assert_eq!(run.status.as_str(), "completed");
    assert_eq!(run.conclusion.as_deref(), Some("failure"));
}
#[test]
fn test_github_client_clone() {
}
}