// git_bot_feedback/client/github/mod.rs

//! This module holds functionality specific to using GitHub's REST API.
//!
//! In the root module, we just implement the `RestApiClient` trait.
//! In other (private) submodules, we implement behavior specific to GitHub's REST API.
5
6use std::{
7    env,
8    fs::OpenOptions,
9    io::{self, Write},
10};
11
12use async_trait::async_trait;
13use reqwest::{Client, Method, Url};
14
15use crate::{
16    FileAnnotation, OutputVariable, ReviewAction, ReviewOptions, ThreadCommentOptions,
17    client::{ClientError, RestApiClient, RestApiRateLimitHeaders},
18};
19mod graphql;
20mod serde_structs;
21use serde_structs::{FullReview, PullRequestInfo, PullRequestState, ReviewDiffComment};
22mod specific_api;
23
24#[cfg(feature = "file-changes")]
25use crate::{FileDiffLines, FileFilter, LinesChangedOnly, parse_diff};
26#[cfg(feature = "file-changes")]
27use std::{collections::HashMap, path::Path};
28
29/// A structure to work with Github REST API.
30pub struct GithubApiClient {
31    /// The HTTP request client to be used for all REST API calls.
32    client: Client,
33
34    /// The CI run's event payload from the webhook that triggered the workflow.
35    pull_request: Option<PullRequestInfo>,
36
37    /// The name of the event that was triggered when running cpp_linter.
38    pub event_name: String,
39
40    /// The value of the `GITHUB_API_URL` environment variable.
41    api_url: Url,
42
43    /// The value of the `GITHUB_REPOSITORY` environment variable.
44    repo: String,
45
46    /// The value of the `GITHUB_SHA` environment variable.
47    sha: String,
48
49    /// The value of the `ACTIONS_STEP_DEBUG` environment variable.
50    pub debug_enabled: bool,
51
52    /// The response header names that describe the rate limit status.
53    rate_limit_headers: RestApiRateLimitHeaders,
54}
55
56// implement the RestApiClient trait for the GithubApiClient
57#[async_trait]
58impl RestApiClient for GithubApiClient {
59    fn client_kind(&self) -> String {
60        "github".to_string()
61    }
62
63    /// This prints a line to indicate the beginning of a related group of [`log`] statements.
64    ///
65    /// For apps' [`log`] implementations, this function's [`log::info`] output needs to have
66    /// no prefixed data.
67    /// Such behavior can be identified by the log target `"CI_LOG_GROUPING"`.
68    ///
69    /// ```
70    /// # struct MyAppLogger;
71    /// impl log::Log for MyAppLogger {
72    /// #    fn enabled(&self, metadata: &log::Metadata) -> bool {
73    /// #        log::max_level() > metadata.level()
74    /// #    }
75    ///     fn log(&self, record: &log::Record) {
76    ///         if record.target() == "CI_LOG_GROUPING" {
77    ///             println!("{}", record.args());
78    ///         } else {
79    ///             println!(
80    ///                 "[{:>5}]{}: {}",
81    ///                 record.level().as_str(),
82    ///                 record.module_path().unwrap_or_default(),
83    ///                 record.args()
84    ///             );
85    ///         }
86    ///     }
87    /// #    fn flush(&self) {}
88    /// }
89    /// ```
90    fn start_log_group(&self, name: &str) {
91        log::info!(target: "CI_LOG_GROUPING", "::group::{name}");
92    }
93
94    /// This prints a line to indicate the ending of a related group of [`log`] statements.
95    ///
96    /// See also [`GithubApiClient::start_log_group`] about special handling of
97    /// the log target `"CI_LOG_GROUPING"`.
98    fn end_log_group(&self, _name: &str) {
99        log::info!(target: "CI_LOG_GROUPING", "::endgroup::");
100    }
101
102    fn event_name(&self) -> Option<String> {
103        Some(self.event_name.clone())
104    }
105
106    fn is_debug_enabled(&self) -> bool {
107        self.debug_enabled
108    }
109
110    fn set_user_agent(&mut self, user_agent: &str) -> Result<(), ClientError> {
111        self.client = Client::builder()
112            .default_headers(Self::make_headers()?)
113            .user_agent(user_agent)
114            .build()?;
115        Ok(())
116    }
117
118    async fn post_thread_comment(&self, options: ThreadCommentOptions) -> Result<(), ClientError> {
119        env::var("GITHUB_TOKEN").map_err(|e| ClientError::env_var("GITHUB_TOKEN", e))?;
120        let comments_url = match &self.pull_request {
121            Some(pr_event) => {
122                if pr_event.locked {
123                    return Ok(()); // cannot comment on locked PRs
124                }
125                self.api_url.join(
126                    format!("repos/{}/issues/{}/comments", self.repo, pr_event.number).as_str(),
127                )?
128            }
129            None => self
130                .api_url
131                .join(format!("repos/{}/commits/{}/comments", self.repo, self.sha).as_str())?,
132        };
133        self.update_comment(comments_url, options).await
134    }
135
136    #[inline]
137    fn is_pr_event(&self) -> bool {
138        self.pull_request.is_some()
139    }
140
141    fn append_step_summary(&self, comment: &str) -> Result<(), ClientError> {
142        let gh_out = env::var("GITHUB_STEP_SUMMARY")
143            .map_err(|e| ClientError::env_var("GITHUB_STEP_SUMMARY", e))?;
144        // step summary MD file can be overwritten/removed in CI runners
145        match OpenOptions::new().append(true).open(gh_out) {
146            Ok(mut gh_out_file) => writeln!(&mut gh_out_file, "\n{comment}\n")
147                .map_err(|e| ClientError::io("write to GITHUB_STEP_SUMMARY file", e)),
148            Err(e) => Err(ClientError::io("open GITHUB_STEP_SUMMARY file", e)),
149        }
150    }
151
152    fn write_output_variables(&self, vars: &[OutputVariable]) -> Result<(), ClientError> {
153        if vars.is_empty() {
154            // Should probably be an error. This check is only here to prevent needlessly
155            // fetching the env var GITHUB_OUTPUT value and opening the referenced file.
156            return Ok(());
157        }
158        let gh_out =
159            env::var("GITHUB_OUTPUT").map_err(|e| ClientError::env_var("GITHUB_OUTPUT", e))?;
160        match OpenOptions::new().append(true).open(gh_out) {
161            Ok(mut gh_out_file) => {
162                for out_var in vars {
163                    out_var.validate()?;
164                    writeln!(&mut gh_out_file, "{out_var}")
165                        .map_err(|e| ClientError::io("write to GITHUB_OUTPUT file", e))?;
166                }
167                Ok(())
168            }
169            Err(e) => Err(ClientError::io("open GITHUB_OUTPUT file", e)),
170        }
171    }
172
173    fn write_file_annotations(&self, annotations: &[FileAnnotation]) -> Result<(), ClientError> {
174        if annotations.is_empty() {
175            // Should probably be an error.
176            // This check is only here to prevent needlessly locking stdout.
177            return Ok(());
178        }
179        let stdout = io::stdout();
180        let mut handle = stdout.lock();
181        for annotation in annotations {
182            writeln!(&mut handle, "{}", annotation.fmt_github())
183                .map_err(|e| ClientError::io("write to file annotation to stdout", e))?;
184        }
185        handle
186            .flush()
187            .map_err(|e| ClientError::io("flush stdout with file annotations", e))?;
188        Ok(())
189    }
190
191    #[cfg(feature = "file-changes")]
192    #[cfg_attr(docsrs, doc(cfg(feature = "file-changes")))]
193    async fn get_list_of_changed_files(
194        &self,
195        file_filter: &FileFilter,
196        lines_changed_only: &LinesChangedOnly,
197        _base_diff: Option<String>,
198        _ignore_index: bool,
199    ) -> Result<HashMap<String, FileDiffLines>, ClientError> {
200        let (url, is_pr) = match &self.pull_request {
201            Some(pr_event) => (
202                self.api_url.join(
203                    format!("repos/{}/pulls/{}/files", self.repo, pr_event.number).as_str(),
204                )?,
205                true,
206            ),
207            None => (
208                self.api_url
209                    .join(format!("repos/{}/commits/{}", self.repo, self.sha).as_str())?,
210                false,
211            ),
212        };
213        let mut url = Some(Url::parse_with_params(url.as_str(), &[("page", "1")])?);
214        let mut files: HashMap<String, FileDiffLines> = HashMap::new();
215        while let Some(ref endpoint) = url {
216            let request =
217                self.make_api_request(&self.client, endpoint.to_owned(), Method::GET, None, None)?;
218            let response = self
219                .send_api_request(&self.client, request, &self.rate_limit_headers)
220                .await
221                .map_err(|e| e.add_request_context("get list of changed files"))?;
222            url = self.try_next_page(response.headers());
223            let body = response.text().await?;
224            let files_list = if !is_pr {
225                let json_value: serde_structs::PushEventFiles = serde_json::from_str(&body)
226                    .map_err(|e| ClientError::json("deserialize list of changed files", e))?;
227                json_value.files
228            } else {
229                serde_json::from_str::<Vec<serde_structs::GithubChangedFile>>(&body)
230                    .map_err(|e| ClientError::json("deserialize list of changed files", e))?
231            };
232            for file in files_list {
233                let ext = Path::new(&file.filename).extension().unwrap_or_default();
234                if !file_filter
235                    .extensions
236                    .contains(&ext.to_string_lossy().to_string())
237                {
238                    continue;
239                }
240                if let Some(patch) = file.patch {
241                    let diff = format!(
242                        "diff --git a/{old} b/{new}\n--- a/{old}\n+++ b/{new}\n{patch}\n",
243                        old = file.previous_filename.unwrap_or(file.filename.clone()),
244                        new = file.filename,
245                    );
246                    for (name, info) in parse_diff(&diff, file_filter, lines_changed_only)? {
247                        files.entry(name).or_insert(info);
248                    }
249                } else if file.changes == 0 {
250                    // file may have been only renamed.
251                    // include it in case files-changed-only is enabled.
252                    files.entry(file.filename).or_default();
253                }
254                // else changes are too big (per git server limits) or we don't care
255            }
256        }
257        Ok(files)
258    }
259
260    async fn cull_pr_reviews(&mut self, options: &mut ReviewOptions) -> Result<(), ClientError> {
261        if let Some(pr_info) = self.pull_request.as_ref() {
262            if pr_info.locked
263                || (!options.allow_closed && pr_info.state == PullRequestState::Closed)
264            {
265                return Ok(());
266            }
267            env::var("GITHUB_TOKEN").map_err(|e| ClientError::env_var("GITHUB_TOKEN", e))?;
268
269            // Check existing comments to see if we can reuse any of them.
270            // This also removes duplicate comments (if any) from the `options.comments`.
271            let keep_reviews = self.check_reused_comments(options).await?;
272            // Next hide/resolve any previous reviews that are completely outdated.
273            let url = self
274                .api_url
275                .join(format!("repos/{}/pulls/{}/reviews", self.repo, pr_info.number).as_str())?;
276            self.hide_outdated_reviews(url, keep_reviews, &options.marker)
277                .await?;
278        }
279        Ok(())
280    }
281
282    async fn post_pr_review(&mut self, options: &ReviewOptions) -> Result<(), ClientError> {
283        if let Some(pr_info) = self.pull_request.as_ref() {
284            if (!options.allow_draft && pr_info.draft)
285                || (!options.allow_closed && pr_info.state == PullRequestState::Closed)
286                || pr_info.locked
287            {
288                return Ok(());
289            }
290            env::var("GITHUB_TOKEN").map_err(|e| ClientError::env_var("GITHUB_TOKEN", e))?;
291            let url = self
292                .api_url
293                .join(format!("repos/{}/pulls/{}/reviews", self.repo, pr_info.number).as_str())?;
294            let payload = FullReview {
295                event: match options.action {
296                    ReviewAction::Comment => String::from("COMMENT"),
297                    ReviewAction::Approve => String::from("APPROVE"),
298                    ReviewAction::RequestChanges => String::from("REQUEST_CHANGES"),
299                },
300                body: format!("{}{}", options.marker, options.summary),
301                comments: options
302                    .comments
303                    .iter()
304                    .map(ReviewDiffComment::from)
305                    .map(|mut r| {
306                        if !r.body.starts_with(&options.marker) {
307                            r.body = format!("{}{}", options.marker, r.body);
308                        }
309                        r
310                    })
311                    .collect(),
312            };
313            let request = self.make_api_request(
314                &self.client,
315                url,
316                Method::POST,
317                Some(
318                    serde_json::to_string(&payload)
319                        .map_err(|e| ClientError::json("serialize PR review payload", e))?,
320                ),
321                None,
322            )?;
323            let response = self
324                .send_api_request(&self.client, request, &self.rate_limit_headers)
325                .await;
326            match response {
327                Ok(response) => {
328                    self.log_response(response, "Failed to post PR review")
329                        .await;
330                }
331                Err(e) => {
332                    return Err(e.add_request_context("post PR review"));
333                }
334            }
335        }
336        Ok(())
337    }
338}