Skip to main content

git_bot_feedback/client/
mod.rs

1//! A module to contain traits and structs that are needed by the rest of the git-bot-feedback crate's API.
2use std::{env, fmt::Debug, time::Duration};
3
4use async_trait::async_trait;
5use chrono::DateTime;
6use reqwest::{Client, Method, Request, Response, Url, header::HeaderMap};
7
8use crate::{FileAnnotation, OutputVariable, RestClientError, ReviewOptions, ThreadCommentOptions};
9
10#[cfg(feature = "github")]
11mod github;
12#[cfg(feature = "github")]
13pub use github::GithubApiClient;
14
15#[cfg(not(any(feature = "github", feature = "custom-git-server-impl")))]
16compile_error!(
17    "At least one Git server implementation (eg. 'github') should be enabled via `features`"
18);
19
20#[cfg(feature = "file-changes")]
21use crate::{FileDiffLines, FileFilter, LinesChangedOnly, parse_diff};
22#[cfg(feature = "file-changes")]
23use std::{collections::HashMap, process::Command};
24
25/// The User-Agent header value included in all HTTP requests.
26pub static USER_AGENT: &str = concat!(env!("CARGO_CRATE_NAME"), "/", env!("CARGO_PKG_VERSION"));
27
/// Names of the HTTP response headers through which a REST API reports its
/// rate-limit status.
///
/// Every field stores a header *key* that is looked up in a response, not the header's value.
#[derive(Clone, Debug)]
pub struct RestApiRateLimitHeaders {
    /// Header key holding the time at which the rate limit resets.
    pub reset: String,
    /// Header key holding the number of requests still allowed.
    pub remaining: String,
    /// Header key holding the interval to wait ("back off") before retrying.
    pub retry: String,
}
39
40/// The [`Result::Err`] type returned for fallible functions in this trait.
41pub(crate) type ClientError = RestClientError;
42
43/// The number of attempts made when contending a secondary rate limit in REST API requests.
44pub(crate) const MAX_RETRIES: u8 = 5;
45
46/// A custom trait that templates necessary functionality with a Git server's REST API.
47#[async_trait]
48pub trait RestApiClient {
49    /// This prints a line to indicate the beginning of a related group of log statements.
50    fn start_log_group(&self, name: &str) {
51        log::info!(target: "CI_LOG_GROUPING", "start_log_group: {name}");
52    }
53
54    /// This prints a line to indicate the ending of a related group of log statements.
55    fn end_log_group(&self, name: &str) {
56        log::info!(target: "CI_LOG_GROUPING", "end_log_group: {name}");
57    }
58
59    /// Is the current CI event **trigger** a Pull Request?
60    ///
61    /// This **will not** check if a push event's instigating commit is part of any PR.
62    fn is_pr_event(&self) -> bool;
63
64    /// A way to get the list of changed files in the context of the CI event.
65    ///
66    /// This method will parse diff blobs and return a list of changed files.
67    ///
68    /// The default implementation uses `git diff` to get the list of changed files.
69    /// So, the default implementation requires `git` installed and a non-shallow checkout.
70    ///
71    /// Other implementations use the Git server's REST API to get the list of changed files.
72    #[cfg(feature = "file-changes")]
73    #[cfg_attr(docsrs, doc(cfg(feature = "file-changes")))]
74    async fn get_list_of_changed_files(
75        &self,
76        file_filter: &FileFilter,
77        lines_changed_only: &LinesChangedOnly,
78        base_diff: Option<String>,
79        ignore_index: bool,
80    ) -> Result<HashMap<String, FileDiffLines>, ClientError> {
81        let git_status = if ignore_index {
82            0
83        } else {
84            match Command::new("git").args(["status", "--short"]).output() {
85                Err(e) => {
86                    return Err(ClientError::io("invoke `git status`", e));
87                }
88                Ok(output) => {
89                    if output.status.success() {
90                        String::from_utf8_lossy(&output.stdout)
91                            .to_string()
92                            // trim last newline to prevent an extra empty line being counted as a changed file
93                            .trim_end_matches('\n')
94                            .lines()
95                            // we only care about staged changes
96                            .filter(|l| !l.starts_with(' '))
97                            .count()
98                    } else {
99                        let err_msg = String::from_utf8_lossy(&output.stderr).to_string();
100                        return Err(ClientError::GitCommand(err_msg));
101                    }
102                }
103            }
104        };
105        let mut diff_args = vec!["diff".to_string()];
106        if git_status != 0 {
107            // There are changes in the working directory.
108            // So, compare include the staged changes.
109            diff_args.push("--staged".to_string());
110        }
111        if let Some(base) = base_diff {
112            match Command::new("git")
113                .args(["rev-parse", base.as_str()])
114                .output()
115            {
116                Err(e) => {
117                    return Err(ClientError::Io {
118                        task: format!("invoke `git rev-parse {base}` to validate reference"),
119                        source: e,
120                    });
121                }
122                Ok(output) => {
123                    if output.status.success() {
124                        diff_args.push(base);
125                    } else if base.chars().all(|c| c.is_ascii_digit()) {
126                        // if all chars form a decimal number, then
127                        // try using it as a number of parents from HEAD
128                        diff_args.push(format!("HEAD~{base}"));
129                        // note, if still not a valid git reference, then
130                        // the error will be raised by the `git diff` command later
131                    } else {
132                        let err_msg = String::from_utf8_lossy(&output.stderr).to_string();
133                        // Given diff base did not resolve to a valid git reference
134                        return Err(ClientError::GitCommand(err_msg));
135                    }
136                }
137            }
138        } else if git_status == 0 {
139            // No base diff provided and there are no staged changes,
140            // just get the diff of the last commit.
141            diff_args.push("HEAD~1".to_string());
142        }
143        match Command::new("git").args(&diff_args).output() {
144            Err(e) => Err(ClientError::Io {
145                task: format!("invoke `git {}`", diff_args.join(" ")),
146                source: e,
147            }),
148            Ok(output) => {
149                if output.status.success() {
150                    let diff_str = String::from_utf8_lossy(&output.stdout).to_string();
151                    let files = parse_diff(&diff_str, file_filter, lines_changed_only);
152                    Ok(files)
153                } else {
154                    let err_msg = String::from_utf8_lossy(&output.stderr).to_string();
155                    Err(ClientError::GitCommand(err_msg))
156                }
157            }
158        }
159    }
160
161    /// A way to post feedback to the Git server's GUI.
162    ///
163    /// The given [`ThreadCommentOptions::comment`] should be compliant with
164    /// the Git server's requirements (ie. the comment length is within acceptable limits).
165    async fn post_thread_comment(&self, options: ThreadCommentOptions) -> Result<(), ClientError>;
166
167    /// Appends a given comment to the CI workflow's summary page.
168    ///
169    /// This is the least obtrusive and recommended for push events.
170    /// Not all Git servers natively support this type of feedback.
171    /// GitHub and Gitea are known to support this.
172    /// For all other git servers, this is a non-op returning [`Ok`]
173    fn append_step_summary(&self, comment: &str) -> Result<(), ClientError> {
174        let _ = comment;
175        Ok(())
176    }
177
178    /// Resolve outdated PR review comments and remove duplicate/reused comments.
179    ///
180    /// This should be used before [`Self::post_pr_review()`] to avoid posting duplicates of existing comments.
181    /// The [`ReviewOptions::comments`] will be modified to only include comments that should be posted for the current PR review.
182    /// After calling this function, the [`ReviewOptions::summary`] can be made to reflect the actual review being posted.
183    ///
184    /// The [`ReviewOptions::marker`] is used to identify comments from this software.
185    /// The [`ReviewOptions::delete_review_comments`] flag will delete outdated review comments.
186    /// The [`ReviewOptions::delete_review_comments`] flag does not apply to review summary comments nor
187    /// threads of discussion within a review.
188    /// A review summary comment will only be hidden/collapsed when all comments in the corresponding
189    /// review are resolved.
190    ///
191    /// This function does nothing for non-PR events.
192    async fn cull_pr_reviews(&mut self, options: &mut ReviewOptions) -> Result<(), ClientError>;
193
194    /// Post a PR review based on the given options.
195    ///
196    /// This is expected to be used after calling [`Self::cull_pr_reviews()`] to
197    /// avoid posting duplicates of existing comments. Once the duplicates are filtered out,
198    /// the [`ReviewOptions::summary`] can be made to reflect the actual review being posted.
199    ///
200    /// This function does nothing for non-PR events.
201    async fn post_pr_review(&mut self, options: &ReviewOptions) -> Result<(), ClientError>;
202
203    /// Sets the given `vars` as output variables.
204    ///
205    /// These variables are designed to be consumed by other steps in the CI workflow.
206    fn write_output_variables(&self, vars: &[OutputVariable]) -> Result<(), ClientError>;
207
208    /// Sets the given `annotations` as file annotations.
209    ///
210    /// Not all Git servers support this on their free tiers, namely GitLab.
211    fn write_file_annotations(&self, annotations: &[FileAnnotation]) -> Result<(), ClientError> {
212        println!("{annotations:#?}");
213        Ok(())
214    }
215
216    /// Construct a HTTP request to be sent.
217    ///
218    /// The idea here is that this method is called before [`send_api_request()`].
219    /// ```ignore
220    /// let request = Self::make_api_request(
221    ///     &self.client,
222    ///     Url::parse("https://example.com").unwrap(),
223    ///     Method::GET,
224    ///     None,
225    ///     None,
226    /// ).unwrap();
227    /// let response = send_api_request(&self.client, request, &self.rest_api_headers);
228    /// match response.await {
229    ///     Ok(res) => todo!(handle response),
230    ///     Err(e) => todo!(handle failure),
231    /// }
232    /// ```
233    fn make_api_request(
234        &self,
235        client: &Client,
236        url: Url,
237        method: Method,
238        data: Option<String>,
239        headers: Option<HeaderMap>,
240    ) -> Result<Request, ClientError> {
241        let mut req = client.request(method, url);
242        if let Some(h) = headers {
243            req = req.headers(h);
244        }
245        if let Some(d) = data {
246            req = req.body(d);
247        }
248        req.build()
249            .map_err(|e| ClientError::add_request_context(ClientError::Request(e), "build request"))
250    }
251
252    /// A convenience function to send HTTP requests and respect a REST API rate limits.
253    ///
254    /// This method respects both primary and secondary rate limits.
255    /// In the event where the secondary rate limits is reached,
256    /// this function will wait for a time interval (if specified by the server) and retry afterward.
257    async fn send_api_request(
258        &self,
259        client: &Client,
260        request: Request,
261        rate_limit_headers: &RestApiRateLimitHeaders,
262    ) -> Result<Response, ClientError> {
263        for i in 0..MAX_RETRIES {
264            let response = client
265                .execute(request.try_clone().ok_or(ClientError::CannotCloneRequest)?)
266                .await?;
267            if [403u16, 429u16].contains(&response.status().as_u16()) {
268                // rate limit may have been exceeded
269
270                // check if primary rate limit was violated
271                let mut requests_remaining = None;
272                if let Some(remaining) = response.headers().get(&rate_limit_headers.remaining) {
273                    requests_remaining = Some(remaining.to_str()?.parse::<i64>()?);
274                } else {
275                    // NOTE: I guess it is sometimes valid for a response to
276                    // not include remaining rate limit attempts
277                    log::debug!("Response headers do not include remaining API usage count");
278                }
279                if requests_remaining.is_some_and(|v| v <= 0) {
280                    if let Some(reset_value) = response.headers().get(&rate_limit_headers.reset)
281                        && let Some(reset) =
282                            DateTime::from_timestamp(reset_value.to_str()?.parse::<i64>()?, 0)
283                    {
284                        return Err(ClientError::RateLimitPrimary(reset));
285                    }
286                    return Err(ClientError::RateLimitNoReset);
287                }
288
289                // check if secondary rate limit is violated. If so, then backoff and try again.
290                if let Some(retry_value) = response.headers().get(&rate_limit_headers.retry) {
291                    let interval = Duration::from_secs(
292                        retry_value.to_str()?.parse::<u64>()? + (i as u64).pow(2),
293                    );
294                    #[cfg(feature = "test-skip-wait-for-rate-limit")]
295                    {
296                        // Output a log statement to use the `interval` variable.
297                        log::warn!(
298                            "Skipped waiting {} seconds to expedite test",
299                            interval.as_secs()
300                        );
301                    }
302                    #[cfg(not(feature = "test-skip-wait-for-rate-limit"))]
303                    {
304                        tokio::time::sleep(interval).await;
305                    }
306                    continue;
307                }
308            }
309            return Ok(response);
310        }
311        Err(ClientError::RateLimitSecondary)
312    }
313
314    /// Gets the URL for the next page from the headers in a paginated response.
315    ///
316    /// Returns [`None`] if current response is the last page.
317    fn try_next_page(&self, headers: &HeaderMap) -> Option<Url> {
318        if let Some(links) = headers.get("link")
319            && let Ok(pg_str) = links.to_str()
320        {
321            let pages = pg_str.split(", ");
322            for page in pages {
323                if page.ends_with("; rel=\"next\"") {
324                    if let Some(link) = page.split_once(">;") {
325                        let url = link.0.trim_start_matches("<").to_string();
326                        if let Ok(next) = Url::parse(&url) {
327                            return Some(next);
328                        } else {
329                            log::debug!("Failed to parse next page link from response header");
330                        }
331                    } else {
332                        log::debug!("Response header link for pagination is malformed");
333                    }
334                }
335            }
336        }
337        None
338    }
339
340    async fn log_response(&self, response: Response, context: &str) {
341        if let Err(e) = response.error_for_status_ref() {
342            log::error!("{}: {e:?}", context.to_owned());
343            if let Ok(text) = response.text().await {
344                log::error!("{text}");
345            }
346        }
347    }
348}