Skip to main content

dagger_sdk/core/
gql_client.rs

1use reqwest::Error;
2use reqwest::{Client, Url};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::fmt::{self, Formatter};
6use std::str::FromStr;
7
8#[derive(Clone)]
9pub struct GraphQLError {
10    message: String,
11    json: Option<Vec<GraphQLErrorMessage>>,
12}
13
/// Typed view of the `extensions` object attached to a GraphQL error entry.
///
/// The variant is selected by the `_type` key (serde internally tagged representation).
/// `_type: "EXEC_ERROR"` yields [`GraphQlExtension::ExecError`]; any other `_type` value yields
/// [`GraphQlExtension::Other`]. An object with no `_type` key at all is rejected when
/// deserialized directly as this type.
#[derive(Deserialize, Debug, Clone)]
#[serde(tag = "_type")]
pub enum GraphQlExtension {
    /// Payload tagged `EXEC_ERROR`.
    #[serde(rename = "EXEC_ERROR")]
    ExecError {
        /// Presumably the executed command and its arguments — TODO confirm with the server schema.
        cmd: Vec<String>,
        /// Exit code, read from the JSON key `exitCode`.
        #[serde(rename(deserialize = "exitCode"))]
        exit_code: i32,
        /// Presumably the command's standard error output.
        stderr: String,
        /// Presumably the command's standard output.
        stdout: String,
    },
    /// Fallback for any unrecognised `_type` value.
    #[serde(other)]
    Other,
}
28
29// https://spec.graphql.org/June2018/#sec-Errors
30#[derive(Deserialize, Debug, Clone)]
31#[allow(dead_code)]
32pub struct GraphQLErrorMessage {
33    pub message: String,
34    pub locations: Option<Vec<GraphQLErrorLocation>>,
35    pub extensions: Option<GraphQlExtension>,
36    pub path: Option<Vec<GraphQLErrorPathParam>>,
37}
38
/// Position in the GraphQL request document that an error refers to.
///
/// NOTE(review): both fields are private and this type has no accessors, so callers can only
/// observe them through `Debug` — confirm whether they should be `pub`.
#[derive(Deserialize, Debug, Clone)]
#[allow(dead_code)]
pub struct GraphQLErrorLocation {
    /// Line number in the request document.
    line: u32,
    /// Column number in the request document.
    column: u32,
}
45
/// One segment of a GraphQL error `path`: a field name or a list index.
///
/// Untagged: serde tries the variants in declaration order. A JSON string becomes `String`,
/// and an integer that fits in `u32` becomes `Number`.
#[derive(Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum GraphQLErrorPathParam {
    /// Field-name segment.
    String(String),
    /// List-index segment.
    Number(u32),
}
52
53impl GraphQLError {
54    pub fn with_text(message: impl AsRef<str>) -> Self {
55        Self {
56            message: message.as_ref().to_string(),
57            json: None,
58        }
59    }
60
61    pub fn with_message_and_json(message: impl AsRef<str>, json: Vec<GraphQLErrorMessage>) -> Self {
62        Self {
63            message: message.as_ref().to_string(),
64            json: Some(json),
65        }
66    }
67
68    pub fn with_json(json: Vec<GraphQLErrorMessage>) -> Self {
69        Self::with_message_and_json("Look at json field for more details", json)
70    }
71
72    pub fn message(&self) -> &str {
73        &self.message
74    }
75
76    pub fn json(&self) -> Option<Vec<GraphQLErrorMessage>> {
77        self.json.clone()
78    }
79}
80
81fn format(err: &GraphQLError, f: &mut Formatter<'_>) -> fmt::Result {
82    // Print the main error message
83    writeln!(f, "\nGQLClient Error: {}", err.message)?;
84
85    // Check if query errors have been received
86    if err.json.is_none() {
87        return Ok(());
88    }
89
90    let errors = err.json.as_ref();
91
92    for err in errors.unwrap() {
93        writeln!(f, "Message: {}", err.message)?;
94    }
95
96    Ok(())
97}
98
99impl fmt::Display for GraphQLError {
100    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
101        format(self, f)
102    }
103}
104
105impl fmt::Debug for GraphQLError {
106    #[allow(clippy::needless_borrow)]
107    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
108        format(&self, f)
109    }
110}
111
112impl From<Error> for GraphQLError {
113    fn from(error: Error) -> Self {
114        Self {
115            message: error.to_string(),
116            json: None,
117        }
118    }
119}
120
/// Configuration for [`GQLClient`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ClientConfig {
    /// URL of the GraphQL server that queries are POSTed to.
    pub endpoint: String,
    /// Overall request timeout in milliseconds (passed to the reqwest builder's `timeout`).
    /// `None` leaves the client's default in place.
    pub execute_timeout_ms: Option<u64>,
    /// Connection timeout in milliseconds (passed to the reqwest builder's `connect_timeout`
    /// via `Duration::from_millis`). `None` leaves the client's default in place.
    pub connect_timeout_ms: Option<u64>,
    /// Extra HTTP headers added to every request.
    pub headers: Option<HashMap<String, String>>,
    /// Optional proxy for outgoing requests.
    pub proxy: Option<GQLProxy>,
}
135
/// Selects which `reqwest::Proxy` constructor a [`GQLProxy`] is converted with.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub enum ProxyType {
    /// Converted with `reqwest::Proxy::http`.
    Http,
    /// Converted with `reqwest::Proxy::https`.
    Https,
    /// Converted with `reqwest::Proxy::all`.
    All,
}
143
/// Basic-auth credentials for a [`GQLProxy`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ProxyAuth {
    /// Proxy user name.
    pub username: String,
    /// Proxy password.
    pub password: String,
}
150
/// Proxy settings for outgoing GraphQL requests.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct GQLProxy {
    /// Proxy URL. Despite the name, this is the proxy address handed to `reqwest::Proxy`.
    pub schema: String,
    /// Which requests the proxy applies to.
    pub type_: ProxyType,
    /// Optional basic-auth credentials for the proxy.
    pub auth: Option<ProxyAuth>,
}
161
162#[cfg(not(target_arch = "wasm32"))]
163impl TryFrom<GQLProxy> for reqwest::Proxy {
164    type Error = GraphQLError;
165
166    fn try_from(gql_proxy: GQLProxy) -> Result<Self, Self::Error> {
167        let proxy = match gql_proxy.type_ {
168            ProxyType::Http => reqwest::Proxy::http(gql_proxy.schema),
169            ProxyType::Https => reqwest::Proxy::https(gql_proxy.schema),
170            ProxyType::All => reqwest::Proxy::all(gql_proxy.schema),
171        }
172        .map_err(|e| Self::Error::with_text(format!("{:?}", e)))?;
173        Ok(proxy)
174    }
175}
176
177#[derive(Clone, Debug)]
178pub struct GQLClient {
179    config: ClientConfig,
180}
181
182#[derive(Serialize)]
183struct RequestBody<T: Serialize> {
184    query: String,
185    variables: T,
186}
187
/// Top-level GraphQL response envelope.
#[derive(Deserialize, Debug)]
struct GraphQLResponse<T> {
    /// Result payload; `None` when the key is absent or `null`.
    data: Option<T>,
    /// Error entries reported by the server, if any.
    errors: Option<Vec<GraphQLErrorMessage>>,
}
193
194impl GQLClient {
195    fn client(&self) -> Result<Client, GraphQLError> {
196        let mut builder = Client::builder();
197
198        if let Some(connect_timeout_ms) = self.config.connect_timeout_ms {
199            builder = builder.connect_timeout(std::time::Duration::from_millis(connect_timeout_ms));
200        }
201
202        if let Some(execute_timeout_ms) = self.config.execute_timeout_ms {
203            builder = builder.timeout(std::time::Duration::from_millis(execute_timeout_ms));
204        }
205
206        if let Some(proxy) = &self.config.proxy {
207            builder = builder.proxy(proxy.clone().try_into()?);
208        }
209
210        builder
211            .build()
212            .map_err(|e| GraphQLError::with_text(format!("Can not create client: {:?}", e)))
213    }
214}
215
216impl GQLClient {
217    pub fn new(endpoint: impl AsRef<str>) -> Self {
218        Self {
219            config: ClientConfig {
220                endpoint: endpoint.as_ref().to_string(),
221                connect_timeout_ms: None,
222                execute_timeout_ms: None,
223                headers: Default::default(),
224                proxy: None,
225            },
226        }
227    }
228
229    pub fn new_with_headers(
230        endpoint: impl AsRef<str>,
231        headers: HashMap<impl ToString, impl ToString>,
232    ) -> Self {
233        let _headers: HashMap<String, String> = headers
234            .iter()
235            .map(|(name, value)| (name.to_string(), value.to_string()))
236            .collect();
237        Self {
238            config: ClientConfig {
239                endpoint: endpoint.as_ref().to_string(),
240                connect_timeout_ms: None,
241                execute_timeout_ms: None,
242                headers: Some(_headers),
243                proxy: None,
244            },
245        }
246    }
247
248    pub fn new_with_config(config: ClientConfig) -> Self {
249        Self { config }
250    }
251}
252
253impl GQLClient {
254    pub async fn query<K>(&self, query: &str) -> Result<Option<K>, GraphQLError>
255    where
256        K: for<'de> Deserialize<'de>,
257    {
258        self.query_with_vars::<K, ()>(query, ()).await
259    }
260
261    pub async fn query_unwrap<K>(&self, query: &str) -> Result<K, GraphQLError>
262    where
263        K: for<'de> Deserialize<'de>,
264    {
265        self.query_with_vars_unwrap::<K, ()>(query, ()).await
266    }
267
268    pub async fn query_with_vars_unwrap<K, T: Serialize>(
269        &self,
270        query: &str,
271        variables: T,
272    ) -> Result<K, GraphQLError>
273    where
274        K: for<'de> Deserialize<'de>,
275    {
276        match self.query_with_vars(query, variables).await? {
277            Some(v) => Ok(v),
278            None => Err(GraphQLError::with_text(format!(
279                "No data from graphql server({}) for this query",
280                self.config.endpoint
281            ))),
282        }
283    }
284
285    pub async fn query_with_vars<K, T: Serialize>(
286        &self,
287        query: &str,
288        variables: T,
289    ) -> Result<Option<K>, GraphQLError>
290    where
291        K: for<'de> Deserialize<'de>,
292    {
293        self.query_with_vars_by_endpoint(&self.config.endpoint, query, variables)
294            .await
295    }
296
297    async fn query_with_vars_by_endpoint<K, T: Serialize>(
298        &self,
299        endpoint: impl AsRef<str>,
300        query: &str,
301        variables: T,
302    ) -> Result<Option<K>, GraphQLError>
303    where
304        K: for<'de> Deserialize<'de>,
305    {
306        let mut times = 1;
307        let mut endpoint = endpoint.as_ref().to_string();
308        let endpoint_url = Url::from_str(&endpoint).map_err(|e| {
309            GraphQLError::with_text(format!("Wrong endpoint: {}. {:?}", endpoint, e))
310        })?;
311        let schema = endpoint_url.scheme();
312        let host = endpoint_url
313            .host()
314            .ok_or_else(|| GraphQLError::with_text(format!("Wrong endpoint: {}", endpoint)))?;
315
316        let client: Client = self.client()?;
317        let body = RequestBody {
318            query: query.to_string(),
319            variables,
320        };
321
322        loop {
323            if times > 10 {
324                return Err(GraphQLError::with_text(format!(
325                    "Many redirect location: {}",
326                    endpoint
327                )));
328            }
329
330            let mut request = client.post(&endpoint).json(&body);
331            if let Some(headers) = &self.config.headers {
332                if !headers.is_empty() {
333                    for (name, value) in headers {
334                        request = request.header(name, value);
335                    }
336                }
337            }
338
339            let raw_response = request.send().await?;
340            if let Some(location) = raw_response.headers().get(reqwest::header::LOCATION) {
341                let redirect_url = location.to_str().map_err(|e| {
342                    GraphQLError::with_text(format!(
343                        "Failed to parse response header: Location. {:?}",
344                        e
345                    ))
346                })?;
347
348                // if the response location start with http:// or https://
349                if redirect_url.starts_with("http://") || redirect_url.starts_with("https://") {
350                    times += 1;
351                    endpoint = redirect_url.to_string();
352                    continue;
353                }
354
355                // without schema
356                endpoint = if redirect_url.starts_with('/') {
357                    format!("{}://{}{}", schema, host, redirect_url)
358                } else {
359                    format!("{}://{}/{}", schema, host, redirect_url)
360                };
361                times += 1;
362                continue;
363            }
364
365            let status = raw_response.status();
366            let response_body_text = raw_response
367                .text()
368                .await
369                .map_err(|e| GraphQLError::with_text(format!("Can not get response: {:?}", e)))?;
370
371            let json: GraphQLResponse<K> =
372                serde_json::from_str(&response_body_text).map_err(|e| {
373                    GraphQLError::with_text(format!(
374                        "Failed to parse response: {:?}. The response body is: {}",
375                        e, response_body_text
376                    ))
377                })?;
378
379            if !status.is_success() {
380                return Err(GraphQLError::with_message_and_json(
381                    format!("The response is [{}]", status.as_u16()),
382                    json.errors.unwrap_or_default(),
383                ));
384            }
385
386            // Check if error messages have been received
387            if json.errors.is_some() {
388                return Err(GraphQLError::with_json(json.errors.unwrap_or_default()));
389            }
390            if json.data.is_none() {
391                tracing::warn!(
392                    target = "gql-client",
393                    response_text = response_body_text,
394                    "The deserialized data is none, the response",
395                );
396            }
397
398            return Ok(json.data);
399        }
400    }
401}