Skip to main content

git_bot_feedback/client/github/
mod.rs

1//! This module holds functionality specific to using Github's REST API.
2//!
3//! In the root module, we just implement the RestApiClient trait.
4//! In other (private) submodules we implement behavior specific to Github's REST API.
5
6use std::{
7    env,
8    fs::OpenOptions,
9    io::{self, Write},
10};
11
12use async_trait::async_trait;
13use reqwest::{Client, Method, Url};
14
15use crate::{
16    FileAnnotation, OutputVariable, ReviewAction, ReviewOptions, ThreadCommentOptions,
17    client::{ClientError, RestApiClient, RestApiRateLimitHeaders},
18};
19mod graphql;
20mod serde_structs;
21use serde_structs::{FullReview, PullRequestInfo, PullRequestState, ReviewDiffComment};
22mod specific_api;
23
24#[cfg(feature = "file-changes")]
25use crate::{FileDiffLines, FileFilter, LinesChangedOnly, parse_diff};
26#[cfg(feature = "file-changes")]
27use std::{collections::HashMap, path::Path};
28
29/// A structure to work with Github REST API.
30pub struct GithubApiClient {
31    /// The HTTP request client to be used for all REST API calls.
32    client: Client,
33
34    /// The CI run's event payload from the webhook that triggered the workflow.
35    pull_request: Option<PullRequestInfo>,
36
37    /// The name of the event that was triggered when running cpp_linter.
38    pub event_name: String,
39
40    /// The value of the `GITHUB_API_URL` environment variable.
41    api_url: Url,
42
43    /// The value of the `GITHUB_REPOSITORY` environment variable.
44    repo: String,
45
46    /// The value of the `GITHUB_SHA` environment variable.
47    sha: String,
48
49    /// The value of the `ACTIONS_STEP_DEBUG` environment variable.
50    pub debug_enabled: bool,
51
52    /// The response header names that describe the rate limit status.
53    rate_limit_headers: RestApiRateLimitHeaders,
54}
55
56// implement the RestApiClient trait for the GithubApiClient
57#[async_trait]
58impl RestApiClient for GithubApiClient {
59    /// This prints a line to indicate the beginning of a related group of [`log`] statements.
60    ///
61    /// For apps' [`log`] implementations, this function's [`log::info`] output needs to have
62    /// no prefixed data.
63    /// Such behavior can be identified by the log target `"CI_LOG_GROUPING"`.
64    ///
65    /// ```
66    /// # struct MyAppLogger;
67    /// impl log::Log for MyAppLogger {
68    /// #    fn enabled(&self, metadata: &log::Metadata) -> bool {
69    /// #        log::max_level() > metadata.level()
70    /// #    }
71    ///     fn log(&self, record: &log::Record) {
72    ///         if record.target() == "CI_LOG_GROUPING" {
73    ///             println!("{}", record.args());
74    ///         } else {
75    ///             println!(
76    ///                 "[{:>5}]{}: {}",
77    ///                 record.level().as_str(),
78    ///                 record.module_path().unwrap_or_default(),
79    ///                 record.args()
80    ///             );
81    ///         }
82    ///     }
83    /// #    fn flush(&self) {}
84    /// }
85    /// ```
86    fn start_log_group(&self, name: &str) {
87        log::info!(target: "CI_LOG_GROUPING", "::group::{name}");
88    }
89
90    /// This prints a line to indicate the ending of a related group of [`log`] statements.
91    ///
92    /// See also [`GithubApiClient::start_log_group`] about special handling of
93    /// the log target `"CI_LOG_GROUPING"`.
94    fn end_log_group(&self, _name: &str) {
95        log::info!(target: "CI_LOG_GROUPING", "::endgroup::");
96    }
97
98    fn event_name(&self) -> Option<String> {
99        Some(self.event_name.clone())
100    }
101
102    fn is_debug_enabled(&self) -> bool {
103        self.debug_enabled
104    }
105
106    fn set_user_agent(&mut self, user_agent: &str) -> Result<(), ClientError> {
107        self.client = Client::builder()
108            .default_headers(Self::make_headers()?)
109            .user_agent(user_agent)
110            .build()?;
111        Ok(())
112    }
113
114    async fn post_thread_comment(&self, options: ThreadCommentOptions) -> Result<(), ClientError> {
115        env::var("GITHUB_TOKEN").map_err(|e| ClientError::env_var("GITHUB_TOKEN", e))?;
116        let comments_url = match &self.pull_request {
117            Some(pr_event) => {
118                if pr_event.locked {
119                    return Ok(()); // cannot comment on locked PRs
120                }
121                self.api_url.join(
122                    format!("repos/{}/issues/{}/comments", self.repo, pr_event.number).as_str(),
123                )?
124            }
125            None => self
126                .api_url
127                .join(format!("repos/{}/commits/{}/comments", self.repo, self.sha).as_str())?,
128        };
129        self.update_comment(comments_url, options).await
130    }
131
132    #[inline]
133    fn is_pr_event(&self) -> bool {
134        self.pull_request.is_some()
135    }
136
137    fn append_step_summary(&self, comment: &str) -> Result<(), ClientError> {
138        let gh_out = env::var("GITHUB_STEP_SUMMARY")
139            .map_err(|e| ClientError::env_var("GITHUB_STEP_SUMMARY", e))?;
140        // step summary MD file can be overwritten/removed in CI runners
141        match OpenOptions::new().append(true).open(gh_out) {
142            Ok(mut gh_out_file) => writeln!(&mut gh_out_file, "\n{comment}\n")
143                .map_err(|e| ClientError::io("write to GITHUB_STEP_SUMMARY file", e)),
144            Err(e) => Err(ClientError::io("open GITHUB_STEP_SUMMARY file", e)),
145        }
146    }
147
148    fn write_output_variables(&self, vars: &[OutputVariable]) -> Result<(), ClientError> {
149        if vars.is_empty() {
150            // Should probably be an error. This check is only here to prevent needlessly
151            // fetching the env var GITHUB_OUTPUT value and opening the referenced file.
152            return Ok(());
153        }
154        let gh_out =
155            env::var("GITHUB_OUTPUT").map_err(|e| ClientError::env_var("GITHUB_OUTPUT", e))?;
156        match OpenOptions::new().append(true).open(gh_out) {
157            Ok(mut gh_out_file) => {
158                for out_var in vars {
159                    out_var.validate()?;
160                    writeln!(&mut gh_out_file, "{out_var}\n")
161                        .map_err(|e| ClientError::io("write to GITHUB_OUTPUT file", e))?;
162                }
163                Ok(())
164            }
165            Err(e) => Err(ClientError::io("open GITHUB_OUTPUT file", e)),
166        }
167    }
168
169    fn write_file_annotations(&self, annotations: &[FileAnnotation]) -> Result<(), ClientError> {
170        if annotations.is_empty() {
171            // Should probably be an error.
172            // This check is only here to prevent needlessly locking stdout.
173            return Ok(());
174        }
175        let stdout = io::stdout();
176        let mut handle = stdout.lock();
177        for annotation in annotations {
178            writeln!(&mut handle, "{annotation}\n")
179                .map_err(|e| ClientError::io("write to file annotation to stdout", e))?;
180        }
181        handle
182            .flush()
183            .map_err(|e| ClientError::io("flush stdout with file annotations", e))?;
184        Ok(())
185    }
186
187    #[cfg(feature = "file-changes")]
188    #[cfg_attr(docsrs, doc(cfg(feature = "file-changes")))]
189    async fn get_list_of_changed_files(
190        &self,
191        file_filter: &FileFilter,
192        lines_changed_only: &LinesChangedOnly,
193        _base_diff: Option<String>,
194        _ignore_index: bool,
195    ) -> Result<HashMap<String, FileDiffLines>, ClientError> {
196        let (url, is_pr) = match &self.pull_request {
197            Some(pr_event) => (
198                self.api_url.join(
199                    format!("repos/{}/pulls/{}/files", self.repo, pr_event.number).as_str(),
200                )?,
201                true,
202            ),
203            None => (
204                self.api_url
205                    .join(format!("repos/{}/commits/{}", self.repo, self.sha).as_str())?,
206                false,
207            ),
208        };
209        let mut url = Some(Url::parse_with_params(url.as_str(), &[("page", "1")])?);
210        let mut files: HashMap<String, FileDiffLines> = HashMap::new();
211        while let Some(ref endpoint) = url {
212            let request =
213                self.make_api_request(&self.client, endpoint.to_owned(), Method::GET, None, None)?;
214            let response = self
215                .send_api_request(&self.client, request, &self.rate_limit_headers)
216                .await
217                .map_err(|e| e.add_request_context("get list of changed files"))?;
218            url = self.try_next_page(response.headers());
219            let body = response.text().await?;
220            let files_list = if !is_pr {
221                let json_value: serde_structs::PushEventFiles = serde_json::from_str(&body)
222                    .map_err(|e| ClientError::json("deserialize list of changed files", e))?;
223                json_value.files
224            } else {
225                serde_json::from_str::<Vec<serde_structs::GithubChangedFile>>(&body)
226                    .map_err(|e| ClientError::json("deserialize list of changed files", e))?
227            };
228            for file in files_list {
229                let ext = Path::new(&file.filename).extension().unwrap_or_default();
230                if !file_filter
231                    .extensions
232                    .contains(&ext.to_string_lossy().to_string())
233                {
234                    continue;
235                }
236                if let Some(patch) = file.patch {
237                    let diff = format!(
238                        "diff --git a/{old} b/{new}\n--- a/{old}\n+++ b/{new}\n{patch}\n",
239                        old = file.previous_filename.unwrap_or(file.filename.clone()),
240                        new = file.filename,
241                    );
242                    for (name, info) in parse_diff(&diff, file_filter, lines_changed_only)? {
243                        files.entry(name).or_insert(info);
244                    }
245                } else if file.changes == 0 {
246                    // file may have been only renamed.
247                    // include it in case files-changed-only is enabled.
248                    files.entry(file.filename).or_default();
249                }
250                // else changes are too big (per git server limits) or we don't care
251            }
252        }
253        Ok(files)
254    }
255
256    async fn cull_pr_reviews(&mut self, options: &mut ReviewOptions) -> Result<(), ClientError> {
257        if let Some(pr_info) = self.pull_request.as_ref() {
258            if pr_info.locked
259                || (!options.allow_closed && pr_info.state == PullRequestState::Closed)
260            {
261                return Ok(());
262            }
263            env::var("GITHUB_TOKEN").map_err(|e| ClientError::env_var("GITHUB_TOKEN", e))?;
264
265            // Check existing comments to see if we can reuse any of them.
266            // This also removes duplicate comments (if any) from the `options.comments`.
267            let keep_reviews = self.check_reused_comments(options).await?;
268            // Next hide/resolve any previous reviews that are completely outdated.
269            let url = self
270                .api_url
271                .join(format!("repos/{}/pulls/{}/reviews", self.repo, pr_info.number).as_str())?;
272            self.hide_outdated_reviews(url, keep_reviews, &options.marker)
273                .await?;
274        }
275        Ok(())
276    }
277
278    async fn post_pr_review(&mut self, options: &ReviewOptions) -> Result<(), ClientError> {
279        if let Some(pr_info) = self.pull_request.as_ref() {
280            if (!options.allow_draft && pr_info.draft)
281                || (!options.allow_closed && pr_info.state == PullRequestState::Closed)
282                || pr_info.locked
283            {
284                return Ok(());
285            }
286            env::var("GITHUB_TOKEN").map_err(|e| ClientError::env_var("GITHUB_TOKEN", e))?;
287            let url = self
288                .api_url
289                .join(format!("repos/{}/pulls/{}/reviews", self.repo, pr_info.number).as_str())?;
290            let payload = FullReview {
291                event: match options.action {
292                    ReviewAction::Comment => String::from("COMMENT"),
293                    ReviewAction::Approve => String::from("APPROVE"),
294                    ReviewAction::RequestChanges => String::from("REQUEST_CHANGES"),
295                },
296                body: format!("{}{}", options.marker, options.summary),
297                comments: options
298                    .comments
299                    .iter()
300                    .map(ReviewDiffComment::from)
301                    .map(|mut r| {
302                        if !r.body.starts_with(&options.marker) {
303                            r.body = format!("{}{}", options.marker, r.body);
304                        }
305                        r
306                    })
307                    .collect(),
308            };
309            let request = self.make_api_request(
310                &self.client,
311                url,
312                Method::POST,
313                Some(
314                    serde_json::to_string(&payload)
315                        .map_err(|e| ClientError::json("serialize PR review payload", e))?,
316                ),
317                None,
318            )?;
319            let response = self
320                .send_api_request(&self.client, request, &self.rate_limit_headers)
321                .await;
322            match response {
323                Ok(response) => {
324                    self.log_response(response, "Failed to post PR review")
325                        .await;
326                }
327                Err(e) => {
328                    return Err(e.add_request_context("post PR review"));
329                }
330            }
331        }
332        Ok(())
333    }
334}