dagger_sdk/core/
gql_client.rs

1use reqwest::Error;
2use reqwest::{Client, Url};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::fmt::{self, Formatter};
6use std::str::FromStr;
7
8#[derive(Clone)]
9pub struct GraphQLError {
10    message: String,
11    json: Option<Vec<GraphQLErrorMessage>>,
12}
13
14// https://spec.graphql.org/June2018/#sec-Errors
15#[derive(Deserialize, Debug, Clone)]
16#[allow(dead_code)]
17pub struct GraphQLErrorMessage {
18    pub message: String,
19    locations: Option<Vec<GraphQLErrorLocation>>,
20    extensions: Option<HashMap<String, String>>,
21    path: Option<Vec<GraphQLErrorPathParam>>,
22}
23
24#[derive(Deserialize, Debug, Clone)]
25#[allow(dead_code)]
26pub struct GraphQLErrorLocation {
27    line: u32,
28    column: u32,
29}
30
31#[derive(Deserialize, Debug, Clone)]
32#[serde(untagged)]
33pub enum GraphQLErrorPathParam {
34    String(String),
35    Number(u32),
36}
37
38impl GraphQLError {
39    pub fn with_text(message: impl AsRef<str>) -> Self {
40        Self {
41            message: message.as_ref().to_string(),
42            json: None,
43        }
44    }
45
46    pub fn with_message_and_json(message: impl AsRef<str>, json: Vec<GraphQLErrorMessage>) -> Self {
47        Self {
48            message: message.as_ref().to_string(),
49            json: Some(json),
50        }
51    }
52
53    pub fn with_json(json: Vec<GraphQLErrorMessage>) -> Self {
54        Self::with_message_and_json("Look at json field for more details", json)
55    }
56
57    pub fn message(&self) -> &str {
58        &self.message
59    }
60
61    pub fn json(&self) -> Option<Vec<GraphQLErrorMessage>> {
62        self.json.clone()
63    }
64}
65
66fn format(err: &GraphQLError, f: &mut Formatter<'_>) -> fmt::Result {
67    // Print the main error message
68    writeln!(f, "\nGQLClient Error: {}", err.message)?;
69
70    // Check if query errors have been received
71    if err.json.is_none() {
72        return Ok(());
73    }
74
75    let errors = err.json.as_ref();
76
77    for err in errors.unwrap() {
78        writeln!(f, "Message: {}", err.message)?;
79    }
80
81    Ok(())
82}
83
84impl fmt::Display for GraphQLError {
85    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
86        format(self, f)
87    }
88}
89
90impl fmt::Debug for GraphQLError {
91    #[allow(clippy::needless_borrow)]
92    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
93        format(&self, f)
94    }
95}
96
97impl From<Error> for GraphQLError {
98    fn from(error: Error) -> Self {
99        Self {
100            message: error.to_string(),
101            json: None,
102        }
103    }
104}
105
106/// GQL client config
107#[derive(Clone, Debug, Deserialize, Serialize)]
108pub struct ClientConfig {
109    /// the endpoint about graphql server
110    pub endpoint: String,
111    /// gql query timeout, unit: seconds
112    pub timeout: Option<u64>,
113    /// additional request header
114    pub headers: Option<HashMap<String, String>>,
115    /// request proxy
116    pub proxy: Option<GQLProxy>,
117}
118
119/// proxy type
120#[derive(Clone, Debug, Deserialize, Serialize)]
121pub enum ProxyType {
122    Http,
123    Https,
124    All,
125}
126
127/// proxy auth, basic_auth
128#[derive(Clone, Debug, Deserialize, Serialize)]
129pub struct ProxyAuth {
130    pub username: String,
131    pub password: String,
132}
133
134/// request proxy
135#[derive(Clone, Debug, Deserialize, Serialize)]
136pub struct GQLProxy {
137    /// schema, proxy url
138    pub schema: String,
139    /// proxy type
140    pub type_: ProxyType,
141    /// auth
142    pub auth: Option<ProxyAuth>,
143}
144
145#[cfg(not(target_arch = "wasm32"))]
146impl TryFrom<GQLProxy> for reqwest::Proxy {
147    type Error = GraphQLError;
148
149    fn try_from(gql_proxy: GQLProxy) -> Result<Self, Self::Error> {
150        let proxy = match gql_proxy.type_ {
151            ProxyType::Http => reqwest::Proxy::http(gql_proxy.schema),
152            ProxyType::Https => reqwest::Proxy::https(gql_proxy.schema),
153            ProxyType::All => reqwest::Proxy::all(gql_proxy.schema),
154        }
155        .map_err(|e| Self::Error::with_text(format!("{:?}", e)))?;
156        Ok(proxy)
157    }
158}
159
160#[derive(Clone, Debug)]
161pub struct GQLClient {
162    config: ClientConfig,
163}
164
165#[derive(Serialize)]
166struct RequestBody<T: Serialize> {
167    query: String,
168    variables: T,
169}
170
171#[derive(Deserialize, Debug)]
172struct GraphQLResponse<T> {
173    data: Option<T>,
174    errors: Option<Vec<GraphQLErrorMessage>>,
175}
176
177impl GQLClient {
178    fn client(&self) -> Result<Client, GraphQLError> {
179        let mut builder = Client::builder().timeout(std::time::Duration::from_secs(
180            self.config.timeout.unwrap_or(5),
181        ));
182        if let Some(proxy) = &self.config.proxy {
183            builder = builder.proxy(proxy.clone().try_into()?);
184        }
185
186        builder
187            .build()
188            .map_err(|e| GraphQLError::with_text(format!("Can not create client: {:?}", e)))
189    }
190}
191
192impl GQLClient {
193    pub fn new(endpoint: impl AsRef<str>) -> Self {
194        Self {
195            config: ClientConfig {
196                endpoint: endpoint.as_ref().to_string(),
197                timeout: None,
198                headers: Default::default(),
199                proxy: None,
200            },
201        }
202    }
203
204    pub fn new_with_headers(
205        endpoint: impl AsRef<str>,
206        headers: HashMap<impl ToString, impl ToString>,
207    ) -> Self {
208        let _headers: HashMap<String, String> = headers
209            .iter()
210            .map(|(name, value)| (name.to_string(), value.to_string()))
211            .collect();
212        Self {
213            config: ClientConfig {
214                endpoint: endpoint.as_ref().to_string(),
215                timeout: None,
216                headers: Some(_headers),
217                proxy: None,
218            },
219        }
220    }
221
222    pub fn new_with_config(config: ClientConfig) -> Self {
223        Self { config }
224    }
225}
226
227impl GQLClient {
228    pub async fn query<K>(&self, query: &str) -> Result<Option<K>, GraphQLError>
229    where
230        K: for<'de> Deserialize<'de>,
231    {
232        self.query_with_vars::<K, ()>(query, ()).await
233    }
234
235    pub async fn query_unwrap<K>(&self, query: &str) -> Result<K, GraphQLError>
236    where
237        K: for<'de> Deserialize<'de>,
238    {
239        self.query_with_vars_unwrap::<K, ()>(query, ()).await
240    }
241
242    pub async fn query_with_vars_unwrap<K, T: Serialize>(
243        &self,
244        query: &str,
245        variables: T,
246    ) -> Result<K, GraphQLError>
247    where
248        K: for<'de> Deserialize<'de>,
249    {
250        match self.query_with_vars(query, variables).await? {
251            Some(v) => Ok(v),
252            None => Err(GraphQLError::with_text(format!(
253                "No data from graphql server({}) for this query",
254                self.config.endpoint
255            ))),
256        }
257    }
258
259    pub async fn query_with_vars<K, T: Serialize>(
260        &self,
261        query: &str,
262        variables: T,
263    ) -> Result<Option<K>, GraphQLError>
264    where
265        K: for<'de> Deserialize<'de>,
266    {
267        self.query_with_vars_by_endpoint(&self.config.endpoint, query, variables)
268            .await
269    }
270
271    async fn query_with_vars_by_endpoint<K, T: Serialize>(
272        &self,
273        endpoint: impl AsRef<str>,
274        query: &str,
275        variables: T,
276    ) -> Result<Option<K>, GraphQLError>
277    where
278        K: for<'de> Deserialize<'de>,
279    {
280        let mut times = 1;
281        let mut endpoint = endpoint.as_ref().to_string();
282        let endpoint_url = Url::from_str(&endpoint).map_err(|e| {
283            GraphQLError::with_text(format!("Wrong endpoint: {}. {:?}", endpoint, e))
284        })?;
285        let schema = endpoint_url.scheme();
286        let host = endpoint_url
287            .host()
288            .ok_or_else(|| GraphQLError::with_text(format!("Wrong endpoint: {}", endpoint)))?;
289
290        let client: Client = self.client()?;
291        let body = RequestBody {
292            query: query.to_string(),
293            variables,
294        };
295
296        loop {
297            if times > 10 {
298                return Err(GraphQLError::with_text(format!(
299                    "Many redirect location: {}",
300                    endpoint
301                )));
302            }
303
304            let mut request = client.post(&endpoint).json(&body);
305            if let Some(headers) = &self.config.headers {
306                if !headers.is_empty() {
307                    for (name, value) in headers {
308                        request = request.header(name, value);
309                    }
310                }
311            }
312
313            let raw_response = request.send().await?;
314            if let Some(location) = raw_response.headers().get(reqwest::header::LOCATION) {
315                let redirect_url = location.to_str().map_err(|e| {
316                    GraphQLError::with_text(format!(
317                        "Failed to parse response header: Location. {:?}",
318                        e
319                    ))
320                })?;
321
322                // if the response location start with http:// or https://
323                if redirect_url.starts_with("http://") || redirect_url.starts_with("https://") {
324                    times += 1;
325                    endpoint = redirect_url.to_string();
326                    continue;
327                }
328
329                // without schema
330                endpoint = if redirect_url.starts_with('/') {
331                    format!("{}://{}{}", schema, host, redirect_url)
332                } else {
333                    format!("{}://{}/{}", schema, host, redirect_url)
334                };
335                times += 1;
336                continue;
337            }
338
339            let status = raw_response.status();
340            let response_body_text = raw_response
341                .text()
342                .await
343                .map_err(|e| GraphQLError::with_text(format!("Can not get response: {:?}", e)))?;
344
345            let json: GraphQLResponse<K> =
346                serde_json::from_str(&response_body_text).map_err(|e| {
347                    GraphQLError::with_text(format!(
348                        "Failed to parse response: {:?}. The response body is: {}",
349                        e, response_body_text
350                    ))
351                })?;
352
353            if !status.is_success() {
354                return Err(GraphQLError::with_message_and_json(
355                    format!("The response is [{}]", status.as_u16()),
356                    json.errors.unwrap_or_default(),
357                ));
358            }
359
360            // Check if error messages have been received
361            if json.errors.is_some() {
362                return Err(GraphQLError::with_json(json.errors.unwrap_or_default()));
363            }
364            if json.data.is_none() {
365                tracing::warn!(
366                    target = "gql-client",
367                    response_text = response_body_text,
368                    "The deserialized data is none, the response",
369                );
370            }
371
372            return Ok(json.data);
373        }
374    }
375}