// discovery_connect/api/query.rs

// Copyright 2023 Ikerian AG
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use graphql_client::GraphQLQuery;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};

20/// This unit is responsible for sending asynchronous GraphQL queries to the Discovery server.
21///
22/// Asynchronously sends a GraphQL query to a server using the provided `QueryClient` and query variables.
23/// This function handles the process of serializing, posting a GraphQL query and deserializing its response.
24///
25/// # Arguments
26/// * `qc` - An `Arc<QueryClient>` instance, used to maintain state across async calls and manage the HTTP client.
27/// * `variables` - Variables required for the GraphQL query, as defined by the `GraphQLQuery` trait.
28///
29/// # Returns
30/// An implementation of `Future`, which when awaited, yields a `Result` containing either the response data
31/// on success (`T::ResponseData`), or an `reqwest::Error` on failure.
32///
33/// # Type Constraints
34/// * `T` - Represents the GraphQL query and must implement `GraphQLQuery`.
35/// * `<T as GraphQLQuery>::Variables` - Must implement `Clone` trait.
36///
37/// # Errors
38/// Returns an `Err` of type `reqwest::Error` if the request fails or if the server
39/// response status is not OK.
40///
41/// # Examples
42/// ```
43/// use std::sync::Arc;
44/// use std::time::Duration;
45/// use crate::discovery_connect::file::{create_file, CreateFile};
46/// use crate::discovery_connect::{post_query, QueryClient};
47///
48/// async fn query_example() {
49///     let timeout = Duration::from_secs(1);
50///     let query_client = Arc::new(QueryClient::new(
51///         "https://api.example.discovery.retinai.com",
52///         "client_id",
53///         "client_secret",
54///         "user@example",
55///         "password123",
56///         timeout,));
57///    let input = create_file::CreateFileInput {
58///        uuid: None,
59///        filename: "example.zip".to_string(),
60///        tags: Some(vec![]),
61///        remarks: Some(serde_json::Value::Object(serde_json::Map::new())),
62///        overwrite: Some(serde_json::Value::Object(serde_json::Map::new())),
63///        workbook_uuid: Some("123e4567-e89b-12d3-a456-426614174000".to_string()),
64///    };
65///    let variables = create_file::Variables { input };
66///    post_query::<CreateFile>(query_client, variables).await;
67/// }
68/// ```
69pub async fn post_query<'a, T: GraphQLQuery + 'a>(
70    qc: Arc<QueryClient>,
71    variables: <T as GraphQLQuery>::Variables,
72) -> Result<T::ResponseData, reqwest::Error>
73where
74    <T as GraphQLQuery>::Variables: Clone,
75{
76    async fn _post_query<'a, T: GraphQLQuery + 'a>(
77        qc: Arc<QueryClient>,
78        url: String,
79        access_token: String,
80        request_body: graphql_client::QueryBody<<T as GraphQLQuery>::Variables>,
81    ) -> Result<T::ResponseData, reqwest::Error> {
82        let result = qc
83            .client
84            .post(url)
85            .bearer_auth(access_token)
86            .json(&request_body);
87
88        let result = result.send().await;
89        match result {
90            Ok(result) => {
91                if result.status() != reqwest::StatusCode::OK {
92                    let e = result.error_for_status().unwrap_err();
93                    eprintln!("Error: {:?}", e);
94                    return Err(e);
95                }
96                result.text().await.map(|text| {
97                    let json: serde_json::Value = serde_json::from_str(&text).unwrap_or_default();
98                    let data = json.get("data");
99                    let data = serde_json::to_string(&data).unwrap_or_default();
100                    match serde_json::from_str(&data) {
101                        Ok(r) => r,
102                        Err(e) => {
103                            eprintln!("Error: {:?}", e);
104                            // TODO: replace with proper error handling
105                            // temporary hack to get this to compile
106                            panic!("Failed to parse response");
107                        }
108                    }
109                })
110            }
111            Err(e) => {
112                eprintln!("Error: {:?}", e);
113                Err(e)
114            }
115        }
116    }
117
118    with_fresh_token(qc.clone(), move || {
119        let url = qc.config.lock().unwrap().url.clone();
120        let url = format!("{}/api", url).clone();
121
122        let access_token: String = match qc.config.lock() {
123            Ok(config) => config.access_token.clone().unwrap(),
124            Err(e) => {
125                panic!("Mutex poisoned: {:?}", e);
126            }
127        };
128        let request_body = T::build_query(variables.clone());
129        Box::pin(_post_query::<'a, T>(
130            qc.clone(),
131            url,
132            access_token,
133            request_body,
134        ))
135    })
136    .await
137}
138
/// Declares a GraphQL operation type for the Discovery ("Disco") schema.
///
/// Expands to a unit struct named `$name` carrying `#[derive(GraphQLQuery)]` and a fixed
/// `#[graphql(...)]` configuration, so every operation in the crate is generated the same way:
///
/// * the schema is always `src/api/graphql/schema.graphql`;
/// * the generated response types derive `Clone, Debug, Default, PartialEq, Deserialize,
///   Serialize`;
/// * the generated variables type derives `Clone` (needed by `post_query`, which clones the
///   variables for every attempt);
/// * a `custom_scalars(...)` entry maps the GraphQL `JSON` scalar to `serde_json::Value`.
///
/// The struct `$name` itself only gets the `GraphQLQuery` derive; the additional derives listed
/// above apply to the types generated for the operation. Because the derive is written
/// unqualified, `GraphQLQuery` must be in scope at the call site
/// (e.g. `use graphql_client::GraphQLQuery;`).
///
/// NOTE(review): the example below also imports `serde_json::Value as JSON`, which suggests the
/// generated code may expect a `JSON` type at the call site regardless of the
/// `custom_scalars(...)` entry. Confirm this against the `graphql_client` version in use.
///
/// # Example
///
/// ```
/// # use graphql_client::GraphQLQuery;
/// # use serde_json::Value as JSON;
/// # use discovery_connect::disco_api;
/// // TODO: the following line fails in doc tests but works in regular code:
/// // disco_api!(CreateFile, "createFile", "src/api/graphql/CreateFile.graphql");
/// ```
///
/// The commented-out invocation would create a `CreateFile` struct bound to the `createFile`
/// operation defined in `src/api/graphql/CreateFile.graphql`.
///
/// # Parameters
///
/// * `$name` - Identifier of the struct to generate.
/// * `$operation_name` - Name of the GraphQL operation inside the query file.
/// * `$query_path` - Path of the `.graphql` file that contains the operation.
#[macro_export]
macro_rules! disco_api {
    ($name:ident, $operation_name:expr, $query_path:expr) => {
        #[derive(GraphQLQuery)]
        #[graphql(
            query_path = $query_path,
            schema_path = "src/api/graphql/schema.graphql",
            response_derives = "Clone,Debug,Default,PartialEq,Deserialize,Serialize",
            variables_derives = "Clone",
            operation_name = $operation_name,
            custom_scalars(
                graphql_type = "JSON",
                rust_type = "serde_json::Value",
                serialize_with = "graphql_client::serialization::AsJson",
                deserialize_with = "graphql_client::deprecation::FromJson"
            )
        )]
        pub struct $name;
    };
}

/// Authentication configuration for the Discovery service.
///
/// Holds the service URL, the client and user credentials used to log in, and the tokens
/// obtained from the most recent successful login or token refresh.
///
/// # Fields
///
/// * `url` - Base URL of the service; GraphQL queries are posted to `<url>/api`.
/// * `client_id` - Client identifier sent with the login request.
/// * `client_secret` - Client secret sent with the login request.
/// * `email` - E-mail address of the user to log in as.
/// * `password` - Password of that user.
/// * `access_token` - Bearer token sent with queries. It is `None` until a login succeeds, or
///   after a failed token refresh has cleared it.
/// * `refresh_token` - Token used to obtain a new access token. It is `None` under the same
///   conditions as `access_token`.
///
/// The `Debug` implementation redacts `client_secret`, `password`, and the token values (it
/// only shows whether each token is present), so a configuration can be logged without leaking
/// credentials.
///
/// # Examples
///
/// ```
/// use discovery_connect::{QueryConfig};
///
/// let query_config = QueryConfig {
///     client_id: "client_id".to_string(),
///     client_secret: "client_secret".to_string(),
///     url: "https://api.europe.discovery.retinai.com".to_string(),
///     email: "user@example.com".to_string(),
///     password: "password123".to_string(),
///     access_token: None,
///     refresh_token: None,
/// };
/// ```
#[derive(Clone)]
pub struct QueryConfig {
    pub url: String,
    pub client_id: String,
    pub client_secret: String,
    pub email: String,
    pub password: String,
    pub access_token: Option<String>,
    pub refresh_token: Option<String>,
}

impl fmt::Debug for QueryConfig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Placeholder printed instead of secret values.
        const REDACTED: &str = "<redacted>";
        f.debug_struct("QueryConfig")
            .field("url", &self.url)
            .field("client_id", &self.client_id)
            .field("client_secret", &REDACTED)
            .field("email", &self.email)
            .field("password", &REDACTED)
            // Keep `Some`/`None` visible: whether a token is present is useful when debugging.
            .field("access_token", &self.access_token.as_ref().map(|_| REDACTED))
            .field("refresh_token", &self.refresh_token.as_ref().map(|_| REDACTED))
            .finish()
    }
}

240/// Encapsulates a client for performing authenticated queries to a service.
241///
242/// # Fields
243///
244/// * `client` - An instance of `reqwest::Client` used for making HTTP requests.
245/// * `config` - A `Mutex<QueryConfig>` that holds the authentication configuration.
246///   The `Mutex` ensures that the configuration can be safely accessed and modified
247///   across multiple threads.
248///
249/// # Examples
250///
251/// ```
252/// use discovery_connect::{QueryClient, QueryConfig};
253/// use std::sync::Mutex;
254///
255/// let query_client = QueryClient {
256///     client: reqwest::Client::new(),
257///     config: Mutex::new(QueryConfig {
258///         client_id: "client_id".to_string(),
259///         client_secret: "client_secret".to_string(),
260///         url: "https://api.europe.discovery.retinai.com".to_string(),
261///         email: "user@example.com".to_string(),
262///         password: "password123".to_string(),
263///         access_token: None,
264///         refresh_token: None,
265///     }),
266/// };
267/// ```
268#[derive(Debug)]
269pub struct QueryClient {
270    pub config: Mutex<QueryConfig>,
271    pub client: reqwest::Client,
272}
273
274impl QueryClient {
275    pub fn new(
276        url: &str,
277        client_id: &str,
278        client_secret: &str,
279        email: &str,
280        password: &str,
281        timeout: std::time::Duration,
282    ) -> QueryClient {
283        let client = reqwest::Client::builder()
284            .timeout(timeout)
285            .build()
286            .unwrap_or_default();
287
288        QueryClient {
289            client,
290            config: Mutex::new(QueryConfig {
291                url: url.to_string(),
292                client_id: client_id.to_string(),
293                client_secret: client_secret.to_string(),
294                email: email.to_string(),
295                password: password.to_string(),
296                access_token: None,
297                refresh_token: None,
298            }),
299        }
300    }
301}
302
303/// Executes a provided asynchronous function, handling token authentication.
304///
305/// This function takes a configuration object and an asynchronous function `func`.
306/// It attempts to execute `func`. If `func` fails due to an unauthorized error,
307/// `with_fresh_token` tries to refresh the access token using the refresh token in the configuration.
308/// If refreshing the token also fails with an unauthorized error, it attempts to log in again
309/// with the credentials in the configuration to obtain new tokens.
310///
311/// # Arguments
312///
313/// * `config` - A shared, thread-safe reference (`Arc`) to an `QueryConfig` object
314///   containing authentication details like URL, email, password, and tokens.
315/// * `func` - The asynchronous function to be executed. This function should return a `Future`
316///   that resolves to a `Result<R, reqwest::Error>`. `R` is the expected successful return type of the function.
317///
318/// # Returns
319///
320/// This function returns a `Future` that resolves to a `Result<R, reqwest::Error>`.
321/// `R` is the return type of the provided asynchronous function `func`.
322/// If `func` executes successfully, its result is returned. If `func` fails due to an unauthorized error,
323/// a token refresh or re-login is attempted, and `func` is retried.
324/// If the token refresh or re-login fails, or if `func` fails due to any other error,
325/// the error is propagated as the return value.
326///
327/// # Examples
328///
329/// ```
330/// use discovery_connect::{QueryConfig, QueryClient, with_fresh_token};
331/// use std::sync::Arc;
332/// use std::sync::Mutex;
333///
334/// async fn example() {
335///     let config = Arc::new(QueryClient {
336///         client: reqwest::Client::new(),
337///         config: Mutex::new(QueryConfig {
338///             client_id: "client_id".to_string(),
339///             client_secret: "client_secret".to_string(),
340///             url: "https://api.europe.discovery.retinai.com".to_string(),
341///             email: "user@example".to_string(),
342///             password: "password123".to_string(),
343///             access_token: None,
344///             refresh_token: None,
345///         })
346///     });
347///     let result = with_fresh_token(config, move || {
348///         // do something that requires authentication here
349///         // ...
350///         Box::pin(async move { Ok(()) })
351///     }).await;
352/// }
353/// ```
354pub async fn with_fresh_token<'a, F, R>(qc: Arc<QueryClient>, func: F) -> Result<R, reqwest::Error>
355where
356    F: Fn() -> Pin<Box<dyn Future<Output = Result<R, reqwest::Error>> + 'a>> + 'a,
357    R: 'a,
358{
359    {
360        // login if there is no refresh token
361        let config = qc.config.lock().unwrap().clone();
362        let access_token = &config.access_token;
363        let refresh_token = &config.refresh_token;
364        if access_token.is_none() || refresh_token.is_none() {
365            match login(qc.clone(), config).await {
366                Ok(config) => {
367                    let mut qc_config = qc.config.lock().unwrap();
368                    qc_config.access_token = config.access_token;
369                    qc_config.refresh_token = config.refresh_token;
370                }
371                Err(e) => {
372                    eprintln!("Error: {:?}", e);
373                    return Err(e);
374                }
375            }
376        }
377    }
378
379    // execute the function and check for authorization error
380    match func().await {
381        Ok(result) => Ok(result),
382        Err(e) => {
383            if e.status() == Some(reqwest::StatusCode::UNAUTHORIZED) {
384                {
385                    // Try to refresh the access token and retry
386                    let config = qc.config.lock().unwrap().clone();
387                    let refresh_token = &config.refresh_token;
388                    match super::auth::access_token(
389                        &qc.client,
390                        &config.url,
391                        refresh_token.as_ref().unwrap(),
392                    )
393                    .await
394                    {
395                        Ok(auth) => {
396                            let mut qc_config = qc.config.lock().unwrap();
397                            qc_config.access_token = Some(auth.access_token.clone());
398                            qc_config.refresh_token = Some(auth.refresh_token.clone());
399                        }
400                        Err(e) => {
401                            eprintln!("Error: {:?}", e);
402                            // failed to refresh tokens, clear to force a fresh login
403                            let mut qc_config = qc.config.lock().unwrap();
404                            qc_config.access_token = None;
405                            qc_config.refresh_token = None;
406                            return Err(e);
407                        }
408                    };
409                }
410
411                // tokens refreshed, retry the function
412                func().await
413            } else {
414                eprintln!("Error: {:?}", e);
415                Err(e)
416            }
417        }
418    }
419}
420
421/// Performs a basic password login and returns the authentication response.
422///
423/// # Arguments
424///
425/// * `qc` - A shared, thread-safe reference (`Arc`) to an `QueryClient` object
426///   containing authentication details like URL, email, password, and tokens.
427/// * `config` - A shared, thread-safe reference (`MutexGuard`) to an `QueryConfig` object
428///   containing authentication details like URL, email, password, and tokens.
429///
430/// # Returns
431///
432/// If the login is successful, the `AuthResponse` is returned.
433/// If the login fails, an `reqwest::Error` is returned.
434pub async fn login(
435    qc: Arc<QueryClient>,
436    config: QueryConfig,
437) -> Result<QueryConfig, reqwest::Error> {
438    match super::auth::login(
439        &qc.client,
440        &config.url,
441        &config.client_id,
442        &config.client_secret,
443        &config.email,
444        &config.password,
445        None,
446    )
447    .await
448    {
449        Ok(auth) => Ok(QueryConfig {
450            access_token: Some(auth.access_token.clone()),
451            refresh_token: Some(auth.refresh_token.clone()),
452            ..config
453        }),
454        Err(e) => Err(e),
455    }
456}