use graphql_client::GraphQLQuery;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
/// Sends the GraphQL operation `T` with `variables` to `{url}/api` and returns the decoded
/// `data` payload.
///
/// The request carries the client's cached access token as a bearer token and runs through
/// [`with_fresh_token`], which logs in first when no tokens are cached and refreshes and
/// retries once when the server answers `401 Unauthorized`.
///
/// # Errors
///
/// Returns the underlying [`reqwest::Error`] when any of the following happens:
/// - the request cannot be sent;
/// - the server answers with a 4xx/5xx status (after the single refresh-and-retry for 401);
/// - the body is not valid JSON;
/// - the response's `data` does not match `T::ResponseData`.
///
/// # Panics
///
/// Panics if the config mutex is poisoned. Also panics if the response is valid JSON but
/// carries no `data` (for example a pure GraphQL error response with `"data": null`),
/// because `reqwest::Error` cannot represent that case.
pub async fn post_query<'a, T: GraphQLQuery + 'a>(
    qc: Arc<QueryClient>,
    variables: <T as GraphQLQuery>::Variables,
) -> Result<T::ResponseData, reqwest::Error>
where
    <T as GraphQLQuery>::Variables: Clone,
{
    /// Sends one authenticated POST of `request_body` to `url` and decodes the GraphQL
    /// `data` payload of the response.
    async fn _post_query<'a, T: GraphQLQuery + 'a>(
        qc: Arc<QueryClient>,
        url: String,
        access_token: String,
        request_body: graphql_client::QueryBody<<T as GraphQLQuery>::Variables>,
    ) -> Result<T::ResponseData, reqwest::Error> {
        let sent = qc
            .client
            .post(url)
            .bearer_auth(access_token)
            .json(&request_body)
            .send()
            .await;
        // `error_for_status` fails only for 4xx/5xx. Any other non-200 status (e.g. 201)
        // falls through to body decoding instead of panicking on `unwrap_err`. A 401 is
        // still returned as an error so `with_fresh_token` can refresh and retry.
        let response = match sent.and_then(reqwest::Response::error_for_status) {
            Ok(response) => response,
            Err(e) => {
                eprintln!("Error: {:?}", e);
                return Err(e);
            }
        };
        // Letting reqwest decode the body turns malformed JSON, or a `data` shape that does
        // not match `T::ResponseData`, into a returned `reqwest::Error` rather than a panic.
        let body = match response
            .json::<graphql_client::Response<T::ResponseData>>()
            .await
        {
            Ok(body) => body,
            Err(e) => {
                eprintln!("Error: {:?}", e);
                return Err(e);
            }
        };
        match body.data {
            Some(data) => Ok(data),
            // `reqwest::Error` has no public constructor, so a response without `data`
            // (e.g. only GraphQL `errors`) cannot be reported through the return type.
            None => {
                eprintln!("Error: {:?}", body.errors);
                panic!("Failed to parse response");
            }
        }
    }
    with_fresh_token(qc.clone(), move || {
        // Read the URL and token under a single lock so both come from the same snapshot;
        // the guard is dropped before the request future is created.
        let (url, access_token) = {
            let config = match qc.config.lock() {
                Ok(config) => config,
                Err(e) => panic!("Mutex poisoned: {:?}", e),
            };
            // `with_fresh_token` logs in before calling us, but a concurrent failed refresh
            // may clear the tokens. An empty token yields a 401 (handled by the refresh
            // logic) instead of a panic here.
            (
                format!("{}/api", config.url),
                config.access_token.clone().unwrap_or_default(),
            )
        };
        let request_body = T::build_query(variables.clone());
        Box::pin(_post_query::<'a, T>(
            qc.clone(),
            url,
            access_token,
            request_body,
        ))
    })
    .await
}
/// Declares a unit struct `$name` and derives `GraphQLQuery` for it.
///
/// - `$query_path`: the query document to compile.
/// - `$operation_name`: the operation selected from that document.
/// - The schema is always `src/api/graphql/schema.graphql`.
///
/// Generated response types derive `Clone, Debug, Default, PartialEq, Deserialize, Serialize`.
/// Generated variables types derive `Clone`, which `post_query` requires.
///
/// The expansion names `GraphQLQuery` without a path, so the invoking module must have it
/// in scope.
///
/// NOTE(review): `schema_path` is presumably resolved relative to the invoking crate's
/// manifest directory — confirm against the graphql_client derive documentation.
///
/// NOTE(review): confirm that the graphql_client version in use supports the
/// `custom_scalars(...)` attribute. Also confirm that the paths
/// `graphql_client::serialization::AsJson` and `graphql_client::deprecation::FromJson`
/// exist. They are meant to map the schema's `JSON` scalar to `serde_json::Value`.
#[macro_export]
macro_rules! disco_api {
    ($name:ident, $operation_name:expr, $query_path:expr) => {
        #[derive(GraphQLQuery)]
        #[graphql(
            query_path = $query_path,
            schema_path = "src/api/graphql/schema.graphql",
            response_derives = "Clone,Debug,Default,PartialEq,Deserialize,Serialize",
            variables_derives = "Clone",
            operation_name = $operation_name,
            custom_scalars(
                graphql_type = "JSON",
                rust_type = "serde_json::Value",
                serialize_with = "graphql_client::serialization::AsJson",
                deserialize_with = "graphql_client::deprecation::FromJson"
            )
        )]
        pub struct $name;
    };
}
/// Connection settings, credentials and cached tokens used to authenticate against a
/// service.
///
/// `Debug` is implemented by hand so that `client_secret`, `password` and the token values
/// are redacted if a config (or a `QueryClient` holding one) is ever logged.
#[derive(Clone)]
pub struct QueryConfig {
    /// Base URL of the service; GraphQL requests are posted to `{url}/api`.
    pub url: String,
    /// Client identifier passed to the login endpoint.
    pub client_id: String,
    /// Client secret passed to the login endpoint. Redacted in `Debug` output.
    pub client_secret: String,
    /// User email passed to the login endpoint.
    pub email: String,
    /// User password passed to the login endpoint. Redacted in `Debug` output.
    pub password: String,
    /// Bearer token attached to requests; `None` until a login or refresh stores one.
    pub access_token: Option<String>,
    /// Token exchanged for a new access token after a 401; `None` until a login stores one.
    pub refresh_token: Option<String>,
}

impl std::fmt::Debug for QueryConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";
        f.debug_struct("QueryConfig")
            .field("url", &self.url)
            .field("client_id", &self.client_id)
            .field("client_secret", &REDACTED)
            .field("email", &self.email)
            .field("password", &REDACTED)
            // Keep whether a token is present (useful when debugging auth flow), hide its value.
            .field("access_token", &self.access_token.as_ref().map(|_| REDACTED))
            .field("refresh_token", &self.refresh_token.as_ref().map(|_| REDACTED))
            .finish()
    }
}
/// Shared state for issuing authenticated requests: the mutable authentication config and
/// the HTTP client.
#[derive(Debug)]
pub struct QueryClient {
    /// Authentication config, including the cached access/refresh tokens. This is a
    /// `std::sync::Mutex`, so its guard should not be held across an `.await`.
    pub config: Mutex<QueryConfig>,
    /// HTTP client used for login, token refresh and GraphQL requests.
    pub client: reqwest::Client,
}
impl QueryClient {
    /// Creates a client for the service at `url` with the given client and user
    /// credentials.
    ///
    /// Both cached tokens start as `None`; they are obtained on first use.
    ///
    /// `timeout` is applied to every request made by the underlying HTTP client. If the
    /// configured client cannot be built, the error is logged and a default client
    /// (without the timeout) is used instead.
    ///
    /// # Panics
    ///
    /// The fallback `reqwest::Client::default()` may itself panic if the HTTP client
    /// cannot be initialised at all.
    pub fn new(
        url: &str,
        client_id: &str,
        client_secret: &str,
        email: &str,
        password: &str,
        timeout: std::time::Duration,
    ) -> QueryClient {
        let client = reqwest::Client::builder()
            .timeout(timeout)
            .build()
            .unwrap_or_else(|e| {
                // Previously swallowed silently; the fallback drops `timeout`, so make it visible.
                eprintln!("Error: {:?}", e);
                reqwest::Client::default()
            });
        QueryClient {
            client,
            config: Mutex::new(QueryConfig {
                url: url.to_string(),
                client_id: client_id.to_string(),
                client_secret: client_secret.to_string(),
                email: email.to_string(),
                password: password.to_string(),
                access_token: None,
                refresh_token: None,
            }),
        }
    }
}
/// Runs `func` with an access token available, retrying it once after renewing the tokens
/// if it fails with `401 Unauthorized`.
///
/// If either cached token is missing, [`login`] is called and the new tokens are stored
/// before `func` runs.
///
/// On a 401 the tokens are renewed and `func` is invoked a second time; that second result
/// is returned as-is:
/// - normally, the cached refresh token is exchanged for new tokens via
///   `super::auth::access_token`;
/// - if no refresh token is cached any more, a fresh [`login`] is performed instead.
///
/// # Errors
///
/// Returns the login or renewal error if new tokens cannot be obtained. A failed renewal
/// also clears both cached tokens so the next call logs in again. Otherwise, returns
/// whatever error `func` produced.
///
/// # Panics
///
/// Panics if the `config` mutex is poisoned.
pub async fn with_fresh_token<'a, F, R>(qc: Arc<QueryClient>, func: F) -> Result<R, reqwest::Error>
where
    F: Fn() -> Pin<Box<dyn Future<Output = Result<R, reqwest::Error>> + 'a>> + 'a,
    R: 'a,
{
    /// Overwrites the cached access and refresh tokens in `qc.config`.
    fn store_tokens(qc: &QueryClient, access_token: Option<String>, refresh_token: Option<String>) {
        let mut config = qc.config.lock().unwrap();
        config.access_token = access_token;
        config.refresh_token = refresh_token;
    }

    // Work on a snapshot so no mutex guard lives across an `.await`.
    let config = qc.config.lock().unwrap().clone();
    if config.access_token.is_none() || config.refresh_token.is_none() {
        match login(qc.clone(), config).await {
            Ok(fresh) => store_tokens(&qc, fresh.access_token, fresh.refresh_token),
            Err(e) => {
                eprintln!("Error: {:?}", e);
                return Err(e);
            }
        }
    }

    let e = match func().await {
        Ok(result) => return Ok(result),
        Err(e) => e,
    };
    if e.status() != Some(reqwest::StatusCode::UNAUTHORIZED) {
        eprintln!("Error: {:?}", e);
        return Err(e);
    }

    let config = qc.config.lock().unwrap().clone();
    let renewed = match config.refresh_token.clone() {
        Some(refresh_token) => {
            super::auth::access_token(&qc.client, &config.url, &refresh_token)
                .await
                .map(|auth| {
                    (
                        Some(auth.access_token.clone()),
                        Some(auth.refresh_token.clone()),
                    )
                })
        }
        // Another caller sharing this client may have cleared the tokens after a failed
        // refresh; log in again instead of panicking on the missing refresh token.
        None => login(qc.clone(), config)
            .await
            .map(|fresh| (fresh.access_token, fresh.refresh_token)),
    };
    match renewed {
        Ok((access_token, refresh_token)) => store_tokens(&qc, access_token, refresh_token),
        Err(e) => {
            eprintln!("Error: {:?}", e);
            store_tokens(&qc, None, None);
            return Err(e);
        }
    }
    func().await
}
/// Authenticates with the credentials held in `config`.
///
/// Returns `config` unchanged except for `access_token` and `refresh_token`, which are set
/// to the tokens issued by `super::auth::login`.
///
/// # Errors
///
/// Propagates the error from `super::auth::login` unchanged.
pub async fn login(
    qc: Arc<QueryClient>,
    config: QueryConfig,
) -> Result<QueryConfig, reqwest::Error> {
    let auth = super::auth::login(
        &qc.client,
        &config.url,
        &config.client_id,
        &config.client_secret,
        &config.email,
        &config.password,
        None,
    )
    .await?;
    let access_token = Some(auth.access_token.clone());
    let refresh_token = Some(auth.refresh_token.clone());
    Ok(QueryConfig {
        access_token,
        refresh_token,
        ..config
    })
}