binstalk_git_repo_api/gh_api_client/error.rs

1use std::{error, fmt, io, time::Duration};
2
3use binstalk_downloader::remote;
4use compact_str::{CompactString, ToCompactString};
5use serde::{de::Deserializer, Deserialize};
6use thiserror::Error as ThisError;
7
8#[derive(ThisError, Debug)]
9#[error("Context: '{context}', err: '{err}'")]
10pub struct GhApiContextError {
11    context: CompactString,
12    #[source]
13    err: GhApiError,
14}
15
16#[derive(ThisError, Debug)]
17#[non_exhaustive]
18pub enum GhApiError {
19    #[error("IO Error: {0}")]
20    Io(#[from] io::Error),
21
22    #[error("Remote Error: {0}")]
23    Remote(#[from] remote::Error),
24
25    #[error("Failed to parse url: {0}")]
26    InvalidUrl(#[from] url::ParseError),
27
28    /// A wrapped error providing the context the error is about.
29    #[error(transparent)]
30    Context(Box<GhApiContextError>),
31
32    #[error("Remote failed to process GraphQL query: {0}")]
33    GraphQLErrors(GhGraphQLErrors),
34
35    #[error("Hit rate-limit, retry after {retry_after:?}")]
36    RateLimit { retry_after: Option<Duration> },
37
38    #[error("Corresponding resource is not found")]
39    NotFound,
40
41    #[error("Does not have permission to access the API")]
42    Unauthorized,
43}
44
45impl GhApiError {
46    /// Attach context to [`GhApiError`]
47    pub fn context(self, context: impl fmt::Display) -> Self {
48        use GhApiError::*;
49
50        if matches!(self, RateLimit { .. } | NotFound | Unauthorized) {
51            self
52        } else {
53            Self::Context(Box::new(GhApiContextError {
54                context: context.to_compact_string(),
55                err: self,
56            }))
57        }
58    }
59}
60
61impl From<GhGraphQLErrors> for GhApiError {
62    fn from(e: GhGraphQLErrors) -> Self {
63        if e.is_rate_limited() {
64            Self::RateLimit { retry_after: None }
65        } else if e.is_not_found_error() {
66            Self::NotFound
67        } else {
68            Self::GraphQLErrors(e)
69        }
70    }
71}
72
73#[derive(Debug, Deserialize)]
74pub struct GhGraphQLErrors(Box<[GraphQLError]>);
75
76impl GhGraphQLErrors {
77    fn is_rate_limited(&self) -> bool {
78        self.0
79            .iter()
80            .any(|error| matches!(error.error_type, GraphQLErrorType::RateLimited))
81    }
82
83    fn is_not_found_error(&self) -> bool {
84        self.0
85            .iter()
86            .any(|error| matches!(&error.error_type, GraphQLErrorType::Other(error_type) if *error_type == "NOT_FOUND"))
87    }
88}
89
90impl error::Error for GhGraphQLErrors {}
91
92impl fmt::Display for GhGraphQLErrors {
93    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94        let last_error_index = self.0.len() - 1;
95
96        for (i, error) in self.0.iter().enumerate() {
97            write!(
98                f,
99                "type: '{error_type}', msg: '{msg}'",
100                error_type = error.error_type,
101                msg = error.message,
102            )?;
103
104            for location in error.locations.as_deref().into_iter().flatten() {
105                write!(
106                    f,
107                    ", occured on query line {line} col {col}",
108                    line = location.line,
109                    col = location.column
110                )?;
111            }
112
113            for (k, v) in &error.others {
114                write!(f, ", {k}: {v}")?;
115            }
116
117            if i < last_error_index {
118                f.write_str("\n")?;
119            }
120        }
121
122        Ok(())
123    }
124}
125
126#[derive(Debug, Deserialize)]
127struct GraphQLError {
128    message: CompactString,
129    locations: Option<Box<[GraphQLLocation]>>,
130
131    #[serde(rename = "type")]
132    error_type: GraphQLErrorType,
133
134    #[serde(flatten, with = "tuple_vec_map")]
135    others: Vec<(CompactString, serde_json::Value)>,
136}
137
138#[derive(Debug)]
139pub(super) enum GraphQLErrorType {
140    RateLimited,
141    Other(CompactString),
142}
143
144impl fmt::Display for GraphQLErrorType {
145    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
146        f.write_str(match self {
147            GraphQLErrorType::RateLimited => "RATE_LIMITED",
148            GraphQLErrorType::Other(s) => s,
149        })
150    }
151}
152
153impl<'de> Deserialize<'de> for GraphQLErrorType {
154    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
155    where
156        D: Deserializer<'de>,
157    {
158        let s = CompactString::deserialize(deserializer)?;
159        Ok(match &*s {
160            "RATE_LIMITED" => GraphQLErrorType::RateLimited,
161            _ => GraphQLErrorType::Other(s),
162        })
163    }
164}
165
166#[derive(Debug, Deserialize)]
167struct GraphQLLocation {
168    line: u64,
169    column: u64,
170}
171
#[cfg(test)]
mod test {
    use super::*;
    use serde::de::value::{BorrowedStrDeserializer, Error};

    /// Panics with a descriptive message unless `$expression` matches
    /// `$pattern` (and the optional guard holds).
    macro_rules! assert_matches {
        ($expression:expr, $pattern:pat $(if $guard:expr)? $(,)?) => {
            match $expression {
                $pattern $(if $guard)? => (),
                expr => {
                    panic!(
                        "assertion failed: `{expr:?}` does not match `{}`",
                        stringify!($pattern $(if $guard)?)
                    )
                }
            }
        };
    }

    /// Deserializes a JSON array of GraphQL error objects.
    fn parse_errors(json: &str) -> GhGraphQLErrors {
        serde_json::from_str(json).unwrap()
    }

    #[test]
    fn test_graph_ql_error_type() {
        let deserialize = |input: &str| {
            GraphQLErrorType::deserialize(BorrowedStrDeserializer::<'_, Error>::new(input)).unwrap()
        };

        assert_matches!(deserialize("RATE_LIMITED"), GraphQLErrorType::RateLimited);
        assert_matches!(
            deserialize("rATE_LIMITED"),
            GraphQLErrorType::Other(val) if val == CompactString::const_new("rATE_LIMITED")
        );
    }

    #[test]
    fn test_graph_ql_errors_conversion() {
        assert_matches!(
            GhApiError::from(parse_errors(
                r#"[{"message": "slow down", "type": "RATE_LIMITED"}]"#
            )),
            GhApiError::RateLimit { retry_after: None }
        );
        assert_matches!(
            GhApiError::from(parse_errors(r#"[{"message": "missing", "type": "NOT_FOUND"}]"#)),
            GhApiError::NotFound
        );
        assert_matches!(
            GhApiError::from(parse_errors(r#"[{"message": "denied", "type": "FORBIDDEN"}]"#)),
            GhApiError::GraphQLErrors(_)
        );
    }

    #[test]
    fn test_graph_ql_errors_display() {
        let errors = parse_errors(
            r#"[
                {
                    "message": "first",
                    "type": "NOT_FOUND",
                    "locations": [{"line": 7, "column": 3}],
                    "path": ["repository"]
                },
                {"message": "second", "type": "RATE_LIMITED"}
            ]"#,
        );

        assert_eq!(
            errors.to_string(),
            "type: 'NOT_FOUND', msg: 'first', occured on query line 7 col 3, path: [\"repository\"]\n\
             type: 'RATE_LIMITED', msg: 'second'"
        );
    }
}