use std::{env, fmt::Debug, time::Duration};
use async_trait::async_trait;
use chrono::DateTime;
use reqwest::{Client, Method, Request, Response, Url, header::HeaderMap};
use crate::{FileAnnotation, OutputVariable, RestClientError, ReviewOptions, ThreadCommentOptions};
#[cfg(feature = "github")]
mod github;
#[cfg(feature = "github")]
pub use github::GithubApiClient;
mod local;
pub use local::LocalClient;
// Build-time guard: this crate is unusable without at least one Git server
// backend compiled in, so fail fast with a clear message instead of letting
// downstream code hit missing-item errors.
#[cfg(not(any(feature = "github", feature = "custom-git-server-impl")))]
compile_error!(
    "At least one Git server implementation (eg. 'github') should be enabled via `features`"
);
#[cfg(feature = "file-changes")]
use crate::{FileDiffLines, FileFilter, LinesChangedOnly};
#[cfg(feature = "file-changes")]
use std::collections::HashMap;
/// Default HTTP `User-Agent` value, `<crate-name>/<version>`, assembled at compile time.
pub static USER_AGENT: &str = concat!(env!("CARGO_CRATE_NAME"), "/", env!("CARGO_PKG_VERSION"));
/// Names of the HTTP response headers a REST API uses to report rate-limit state.
///
/// Different Git servers use different header *names*, so each client
/// implementation supplies its own set; `RestApiClient::send_api_request`
/// looks headers up by these names when deciding whether to wait and retry.
#[derive(Debug, Clone)]
pub struct RestApiRateLimitHeaders {
    /// Header whose value is the Unix timestamp at which the rate-limit window resets.
    pub reset: String,
    /// Header whose value is the number of requests remaining in the current window.
    pub remaining: String,
    /// Header whose value is the number of seconds to wait before retrying.
    pub retry: String,
}
/// Crate-internal shorthand for the public [`RestClientError`] type.
pub(crate) type ClientError = RestClientError;
/// Maximum number of attempts made by `RestApiClient::send_api_request`
/// before giving up with `RestClientError::RateLimitSecondary`.
pub(crate) const MAX_RETRIES: u8 = 5;
#[async_trait]
pub trait RestApiClient {
fn start_log_group(&self, name: &str) {
log::info!(target: "CI_LOG_GROUPING", "start_log_group: {name}");
}
fn end_log_group(&self, name: &str) {
log::info!(target: "CI_LOG_GROUPING", "end_log_group: {name}");
}
fn is_pr_event(&self) -> bool;
fn is_debug_enabled(&self) -> bool {
false
}
fn event_name(&self) -> Option<String> {
None
}
fn set_user_agent(&mut self, user_agent: &str) -> Result<(), ClientError>;
#[cfg(feature = "file-changes")]
#[cfg_attr(docsrs, doc(cfg(feature = "file-changes")))]
async fn get_list_of_changed_files(
&self,
file_filter: &FileFilter,
lines_changed_only: &LinesChangedOnly,
base_diff: Option<String>,
ignore_index: bool,
) -> Result<HashMap<String, FileDiffLines>, ClientError>;
async fn post_thread_comment(&self, options: ThreadCommentOptions) -> Result<(), ClientError>;
fn append_step_summary(&self, comment: &str) -> Result<(), ClientError> {
let _ = comment;
Ok(())
}
async fn cull_pr_reviews(&mut self, options: &mut ReviewOptions) -> Result<(), ClientError>;
async fn post_pr_review(&mut self, options: &ReviewOptions) -> Result<(), ClientError>;
fn write_output_variables(&self, vars: &[OutputVariable]) -> Result<(), ClientError>;
fn write_file_annotations(&self, annotations: &[FileAnnotation]) -> Result<(), ClientError> {
for annotation in annotations {
log::info!("{annotation:#?}");
}
Ok(())
}
fn make_api_request(
&self,
client: &Client,
url: Url,
method: Method,
data: Option<String>,
headers: Option<HeaderMap>,
) -> Result<Request, ClientError> {
let mut req = client.request(method, url);
if let Some(h) = headers {
req = req.headers(h);
}
if let Some(d) = data {
req = req.body(d);
}
req.build()
.map_err(|e| ClientError::add_request_context(ClientError::Request(e), "build request"))
}
async fn send_api_request(
&self,
client: &Client,
request: Request,
rate_limit_headers: &RestApiRateLimitHeaders,
) -> Result<Response, ClientError> {
for i in 0..MAX_RETRIES {
let response = client
.execute(request.try_clone().ok_or(ClientError::CannotCloneRequest)?)
.await?;
if [403u16, 429u16].contains(&response.status().as_u16()) {
let mut requests_remaining = None;
if let Some(remaining) = response.headers().get(&rate_limit_headers.remaining) {
requests_remaining = Some(remaining.to_str()?.parse::<i64>()?);
} else {
log::debug!("Response headers do not include remaining API usage count");
}
if requests_remaining.is_some_and(|v| v <= 0) {
if let Some(reset_value) = response.headers().get(&rate_limit_headers.reset)
&& let Some(reset) =
DateTime::from_timestamp(reset_value.to_str()?.parse::<i64>()?, 0)
{
return Err(ClientError::RateLimitPrimary(reset));
}
return Err(ClientError::RateLimitNoReset);
}
if let Some(retry_value) = response.headers().get(&rate_limit_headers.retry) {
let interval = Duration::from_secs(
retry_value.to_str()?.parse::<u64>()? + (i as u64).pow(2),
);
#[cfg(feature = "test-skip-wait-for-rate-limit")]
{
log::warn!(
"Skipped waiting {} seconds to expedite test",
interval.as_secs()
);
}
#[cfg(not(feature = "test-skip-wait-for-rate-limit"))]
{
tokio::time::sleep(interval).await;
}
continue;
}
}
return Ok(response);
}
Err(ClientError::RateLimitSecondary)
}
fn try_next_page(&self, headers: &HeaderMap) -> Option<Url> {
if let Some(links) = headers.get("link")
&& let Ok(pg_str) = links.to_str()
{
let pages = pg_str.split(", ");
for page in pages {
if page.ends_with("; rel=\"next\"") {
if let Some(link) = page.split_once(">;") {
let url = link.0.trim_start_matches("<").to_string();
if let Ok(next) = Url::parse(&url) {
return Some(next);
} else {
log::debug!("Failed to parse next page link from response header");
}
} else {
log::debug!("Response header link for pagination is malformed");
}
}
}
}
None
}
async fn log_response(&self, response: Response, context: &str) {
if let Err(e) = response.error_for_status_ref() {
log::error!("{}: {e:?}", context.to_owned());
if let Ok(text) = response.text().await {
log::error!("{text}");
}
}
}
}
pub fn init_client() -> Result<Box<dyn RestApiClient + Send + Sync>, ClientError> {
if env::var("GITHUB_ACTIONS").is_ok_and(|v| v.to_lowercase() == "true") {
Ok(Box::new(GithubApiClient::new()?))
} else {
Ok(Box::new(LocalClient))
}
}