discovery_connect/api/query.rs
1// Copyright 2023 Ikerian AG
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use graphql_client::GraphQLQuery;
16use std::future::Future;
17use std::pin::Pin;
18use std::sync::{Arc, Mutex};
19
20/// This unit is responsible for sending asynchronous GraphQL queries to the Discovery server.
21///
22/// Asynchronously sends a GraphQL query to a server using the provided `QueryClient` and query variables.
23/// This function handles the process of serializing, posting a GraphQL query and deserializing its response.
24///
25/// # Arguments
26/// * `qc` - An `Arc<QueryClient>` instance, used to maintain state across async calls and manage the HTTP client.
27/// * `variables` - Variables required for the GraphQL query, as defined by the `GraphQLQuery` trait.
28///
29/// # Returns
30/// An implementation of `Future`, which when awaited, yields a `Result` containing either the response data
31/// on success (`T::ResponseData`), or an `reqwest::Error` on failure.
32///
33/// # Type Constraints
34/// * `T` - Represents the GraphQL query and must implement `GraphQLQuery`.
35/// * `<T as GraphQLQuery>::Variables` - Must implement `Clone` trait.
36///
37/// # Errors
38/// Returns an `Err` of type `reqwest::Error` if the request fails or if the server
39/// response status is not OK.
40///
41/// # Examples
42/// ```
43/// use std::sync::Arc;
44/// use std::time::Duration;
45/// use crate::discovery_connect::file::{create_file, CreateFile};
46/// use crate::discovery_connect::{post_query, QueryClient};
47///
48/// async fn query_example() {
49/// let timeout = Duration::from_secs(1);
50/// let query_client = Arc::new(QueryClient::new(
51/// "https://api.example.discovery.retinai.com",
52/// "client_id",
53/// "client_secret",
54/// "user@example",
55/// "password123",
56/// timeout,));
57/// let input = create_file::CreateFileInput {
58/// uuid: None,
59/// filename: "example.zip".to_string(),
60/// tags: Some(vec![]),
61/// remarks: Some(serde_json::Value::Object(serde_json::Map::new())),
62/// overwrite: Some(serde_json::Value::Object(serde_json::Map::new())),
63/// workbook_uuid: Some("123e4567-e89b-12d3-a456-426614174000".to_string()),
64/// };
65/// let variables = create_file::Variables { input };
66/// post_query::<CreateFile>(query_client, variables).await;
67/// }
68/// ```
69pub async fn post_query<'a, T: GraphQLQuery + 'a>(
70 qc: Arc<QueryClient>,
71 variables: <T as GraphQLQuery>::Variables,
72) -> Result<T::ResponseData, reqwest::Error>
73where
74 <T as GraphQLQuery>::Variables: Clone,
75{
76 async fn _post_query<'a, T: GraphQLQuery + 'a>(
77 qc: Arc<QueryClient>,
78 url: String,
79 access_token: String,
80 request_body: graphql_client::QueryBody<<T as GraphQLQuery>::Variables>,
81 ) -> Result<T::ResponseData, reqwest::Error> {
82 let result = qc
83 .client
84 .post(url)
85 .bearer_auth(access_token)
86 .json(&request_body);
87
88 let result = result.send().await;
89 match result {
90 Ok(result) => {
91 if result.status() != reqwest::StatusCode::OK {
92 let e = result.error_for_status().unwrap_err();
93 eprintln!("Error: {:?}", e);
94 return Err(e);
95 }
96 result.text().await.map(|text| {
97 let json: serde_json::Value = serde_json::from_str(&text).unwrap_or_default();
98 let data = json.get("data");
99 let data = serde_json::to_string(&data).unwrap_or_default();
100 match serde_json::from_str(&data) {
101 Ok(r) => r,
102 Err(e) => {
103 eprintln!("Error: {:?}", e);
104 // TODO: replace with proper error handling
105 // temporary hack to get this to compile
106 panic!("Failed to parse response");
107 }
108 }
109 })
110 }
111 Err(e) => {
112 eprintln!("Error: {:?}", e);
113 Err(e)
114 }
115 }
116 }
117
118 with_fresh_token(qc.clone(), move || {
119 let url = qc.config.lock().unwrap().url.clone();
120 let url = format!("{}/api", url).clone();
121
122 let access_token: String = match qc.config.lock() {
123 Ok(config) => config.access_token.clone().unwrap(),
124 Err(e) => {
125 panic!("Mutex poisoned: {:?}", e);
126 }
127 };
128 let request_body = T::build_query(variables.clone());
129 Box::pin(_post_query::<'a, T>(
130 qc.clone(),
131 url,
132 access_token,
133 request_body,
134 ))
135 })
136 .await
137}
138
139/// Macro to define types for the Disco GraphQL schema.
140///
141/// This macro simplifies the creation of types associated with GraphQL queries by automatically
142/// deriving necessary traits and setting up schema configurations. It is intended to be used
143/// in place of the `#[graphql]` attribute macro directly, providing a more concise syntax and
144/// ensuring consistent configuration across different query types.
145///
146/// The macro generates a struct with the given identifier that represents a GraphQL query or mutation,
147/// along with associated traits like `Debug`, `Default`, `PartialEq`, and `Serialize`. It also
148/// handles custom scalar types, particularly mapping the GraphQL "JSON" type to the Rust equivalent.
149///
150/// # Usage
151///
152/// `disco_api!` macro is used by passing the name of the struct to be generated and the GraphQL
153/// operation name. The macro expands to a struct definition with the `#[derive(GraphQLQuery)]`
154/// attribute, along with necessary configurations.
155///
156/// # Example
157///
158/// ```
159/// # use graphql_client::GraphQLQuery;
160/// # use serde_json::Value as JSON;
161/// # use discovery_connect::disco_api;
162/// // TODO: the following line fails in doc tests but works in regular code:
163/// // disco_api!(CreateFile, "createFile", "src/api/graphql/CreateFile.graphql");
164/// ```
165///
166/// This example creates a `CreateFile` struct, which is tied to the "createFile" GraphQL operation
167/// defined in file "src/api/graphql/CreateFile.graphql".
168/// The struct will automatically derive `Debug`, `Default`, `PartialEq`, and `Serialize` traits, and
169/// will use custom scalar configuration for handling JSON types.
170///
171/// # Parameters
172///
173/// * `$name`: The struct name to be generated.
174/// * `$operation_name`: The name of the GraphQL operation this struct represents.
175#[macro_export]
176macro_rules! disco_api {
177 ($name:ident, $operation_name:expr, $query_path:expr) => {
178 #[derive(GraphQLQuery)]
179 #[graphql(
180 query_path = $query_path,
181 schema_path = "src/api/graphql/schema.graphql",
182 response_derives = "Clone,Debug,Default,PartialEq,Deserialize,Serialize",
183 variables_derives = "Clone",
184 operation_name = $operation_name,
185 custom_scalars(
186 graphql_type = "JSON",
187 rust_type = "serde_json::Value",
188 serialize_with = "graphql_client::serialization::AsJson",
189 deserialize_with = "graphql_client::deprecation::FromJson"
190 )
191 )]
192 pub struct $name;
193 };
194}
195
196/// Represents the authentication configuration for a service.
197///
198/// This structure holds various credentials and tokens necessary for authenticating
199/// with a specific service. It includes the service's URL, user credentials, and tokens
200/// used for accessing protected resources.
201///
202/// # Fields
203///
204/// * `url` - A `String` representing the URL of the service to which queries will be sent.
205/// * `email` - A `String` representing the user's email address used for authentication.
206/// * `region` - A `Region` enum representing the region of the service.
207/// * `password` - A `String` representing the user's password used for authentication.
208/// * `access_token` - An `Option<String>` that holds the access token if available.
209/// This token is used to authenticate queries. It can be `None` if the token has not been acquired yet.
210/// * `refresh_token` - An `Option<String>` that holds the refresh token if available.
211/// This token is used to obtain a new access token when the current one expires.
212/// It can be `None` if the token has not been acquired yet.
213///
214/// # Examples
215///
216/// ```
217/// use discovery_connect::{QueryConfig};
218///
219/// let query_config = QueryConfig {
220/// client_id: "client_id".to_string(),
221/// client_secret: "client_secret".to_string(),
222/// url: "https://api.europe.discovery.retinai.com".to_string(),
223/// email: "user@example.com".to_string(),
224/// password: "password123".to_string(),
225/// access_token: None,
226/// refresh_token: None,
227/// };
228/// ```
229#[derive(Clone, Debug)]
230pub struct QueryConfig {
231 pub url: String,
232 pub client_id: String,
233 pub client_secret: String,
234 pub email: String,
235 pub password: String,
236 pub access_token: Option<String>,
237 pub refresh_token: Option<String>,
238}
239
240/// Encapsulates a client for performing authenticated queries to a service.
241///
242/// # Fields
243///
244/// * `client` - An instance of `reqwest::Client` used for making HTTP requests.
245/// * `config` - A `Mutex<QueryConfig>` that holds the authentication configuration.
246/// The `Mutex` ensures that the configuration can be safely accessed and modified
247/// across multiple threads.
248///
249/// # Examples
250///
251/// ```
252/// use discovery_connect::{QueryClient, QueryConfig};
253/// use std::sync::Mutex;
254///
255/// let query_client = QueryClient {
256/// client: reqwest::Client::new(),
257/// config: Mutex::new(QueryConfig {
258/// client_id: "client_id".to_string(),
259/// client_secret: "client_secret".to_string(),
260/// url: "https://api.europe.discovery.retinai.com".to_string(),
261/// email: "user@example.com".to_string(),
262/// password: "password123".to_string(),
263/// access_token: None,
264/// refresh_token: None,
265/// }),
266/// };
267/// ```
268#[derive(Debug)]
269pub struct QueryClient {
270 pub config: Mutex<QueryConfig>,
271 pub client: reqwest::Client,
272}
273
274impl QueryClient {
275 pub fn new(
276 url: &str,
277 client_id: &str,
278 client_secret: &str,
279 email: &str,
280 password: &str,
281 timeout: std::time::Duration,
282 ) -> QueryClient {
283 let client = reqwest::Client::builder()
284 .timeout(timeout)
285 .build()
286 .unwrap_or_default();
287
288 QueryClient {
289 client,
290 config: Mutex::new(QueryConfig {
291 url: url.to_string(),
292 client_id: client_id.to_string(),
293 client_secret: client_secret.to_string(),
294 email: email.to_string(),
295 password: password.to_string(),
296 access_token: None,
297 refresh_token: None,
298 }),
299 }
300 }
301}
302
303/// Executes a provided asynchronous function, handling token authentication.
304///
305/// This function takes a configuration object and an asynchronous function `func`.
306/// It attempts to execute `func`. If `func` fails due to an unauthorized error,
307/// `with_fresh_token` tries to refresh the access token using the refresh token in the configuration.
308/// If refreshing the token also fails with an unauthorized error, it attempts to log in again
309/// with the credentials in the configuration to obtain new tokens.
310///
311/// # Arguments
312///
313/// * `config` - A shared, thread-safe reference (`Arc`) to an `QueryConfig` object
314/// containing authentication details like URL, email, password, and tokens.
315/// * `func` - The asynchronous function to be executed. This function should return a `Future`
316/// that resolves to a `Result<R, reqwest::Error>`. `R` is the expected successful return type of the function.
317///
318/// # Returns
319///
320/// This function returns a `Future` that resolves to a `Result<R, reqwest::Error>`.
321/// `R` is the return type of the provided asynchronous function `func`.
322/// If `func` executes successfully, its result is returned. If `func` fails due to an unauthorized error,
323/// a token refresh or re-login is attempted, and `func` is retried.
324/// If the token refresh or re-login fails, or if `func` fails due to any other error,
325/// the error is propagated as the return value.
326///
327/// # Examples
328///
329/// ```
330/// use discovery_connect::{QueryConfig, QueryClient, with_fresh_token};
331/// use std::sync::Arc;
332/// use std::sync::Mutex;
333///
334/// async fn example() {
335/// let config = Arc::new(QueryClient {
336/// client: reqwest::Client::new(),
337/// config: Mutex::new(QueryConfig {
338/// client_id: "client_id".to_string(),
339/// client_secret: "client_secret".to_string(),
340/// url: "https://api.europe.discovery.retinai.com".to_string(),
341/// email: "user@example".to_string(),
342/// password: "password123".to_string(),
343/// access_token: None,
344/// refresh_token: None,
345/// })
346/// });
347/// let result = with_fresh_token(config, move || {
348/// // do something that requires authentication here
349/// // ...
350/// Box::pin(async move { Ok(()) })
351/// }).await;
352/// }
353/// ```
354pub async fn with_fresh_token<'a, F, R>(qc: Arc<QueryClient>, func: F) -> Result<R, reqwest::Error>
355where
356 F: Fn() -> Pin<Box<dyn Future<Output = Result<R, reqwest::Error>> + 'a>> + 'a,
357 R: 'a,
358{
359 {
360 // login if there is no refresh token
361 let config = qc.config.lock().unwrap().clone();
362 let access_token = &config.access_token;
363 let refresh_token = &config.refresh_token;
364 if access_token.is_none() || refresh_token.is_none() {
365 match login(qc.clone(), config).await {
366 Ok(config) => {
367 let mut qc_config = qc.config.lock().unwrap();
368 qc_config.access_token = config.access_token;
369 qc_config.refresh_token = config.refresh_token;
370 }
371 Err(e) => {
372 eprintln!("Error: {:?}", e);
373 return Err(e);
374 }
375 }
376 }
377 }
378
379 // execute the function and check for authorization error
380 match func().await {
381 Ok(result) => Ok(result),
382 Err(e) => {
383 if e.status() == Some(reqwest::StatusCode::UNAUTHORIZED) {
384 {
385 // Try to refresh the access token and retry
386 let config = qc.config.lock().unwrap().clone();
387 let refresh_token = &config.refresh_token;
388 match super::auth::access_token(
389 &qc.client,
390 &config.url,
391 refresh_token.as_ref().unwrap(),
392 )
393 .await
394 {
395 Ok(auth) => {
396 let mut qc_config = qc.config.lock().unwrap();
397 qc_config.access_token = Some(auth.access_token.clone());
398 qc_config.refresh_token = Some(auth.refresh_token.clone());
399 }
400 Err(e) => {
401 eprintln!("Error: {:?}", e);
402 // failed to refresh tokens, clear to force a fresh login
403 let mut qc_config = qc.config.lock().unwrap();
404 qc_config.access_token = None;
405 qc_config.refresh_token = None;
406 return Err(e);
407 }
408 };
409 }
410
411 // tokens refreshed, retry the function
412 func().await
413 } else {
414 eprintln!("Error: {:?}", e);
415 Err(e)
416 }
417 }
418 }
419}
420
421/// Performs a basic password login and returns the authentication response.
422///
423/// # Arguments
424///
425/// * `qc` - A shared, thread-safe reference (`Arc`) to an `QueryClient` object
426/// containing authentication details like URL, email, password, and tokens.
427/// * `config` - A shared, thread-safe reference (`MutexGuard`) to an `QueryConfig` object
428/// containing authentication details like URL, email, password, and tokens.
429///
430/// # Returns
431///
432/// If the login is successful, the `AuthResponse` is returned.
433/// If the login fails, an `reqwest::Error` is returned.
434pub async fn login(
435 qc: Arc<QueryClient>,
436 config: QueryConfig,
437) -> Result<QueryConfig, reqwest::Error> {
438 match super::auth::login(
439 &qc.client,
440 &config.url,
441 &config.client_id,
442 &config.client_secret,
443 &config.email,
444 &config.password,
445 None,
446 )
447 .await
448 {
449 Ok(auth) => Ok(QueryConfig {
450 access_token: Some(auth.access_token.clone()),
451 refresh_token: Some(auth.refresh_token.clone()),
452 ..config
453 }),
454 Err(e) => Err(e),
455 }
456}