#![deny(missing_docs)]
#![deny(rust_2018_idioms)]
mod ws;
use async_trait::async_trait;
use futures_util::{stream::Stream, SinkExt, StreamExt};
use reqwest::{
header::{HeaderMap, HeaderValue, AUTHORIZATION},
Url,
};
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
use tokio_tungstenite::tungstenite::{self, http::Request};
use ws::GraphQLWebSocket;
use std::pin::Pin;
use std::collections::HashMap;
use std::{
fmt::{self, Display},
};
/// A client that can run a single request and decode the outcome into `T`.
///
/// `T` must be deserializable for any input lifetime and must outlive `'a`, which is also the
/// lifetime of the borrow of the executor taken by [`Executor::execute`].
#[async_trait]
pub trait Executor<'a, T>: Sync
where
T: for<'de> Deserialize<'de> + 'a,
{
/// Runs the request described by `request_body`.
///
/// Resolves to the decoded `T` on success and to an [`Error`] otherwise.
async fn execute(&'a self, request_body: RequestBody) -> Result<T, Error>;
}
/// A pinned, boxed, `Send` stream whose items are either a decoded `T` or an [`Error`].
pub type SubscriptionStream<T> = Pin<Box<dyn Stream<Item = Result<T, Error>> + Send>>;
/// A client that can open a subscription and expose its messages as a stream of `T`.
///
/// `T` must be deserializable for any input lifetime, `Unpin`, `Send` and `'static`.
#[async_trait]
pub trait Subscriber<T>: Sync
where
T: for<'de> Deserialize<'de> + Unpin + Send + 'static,
{
/// Starts the subscription described by `request_body`.
///
/// Resolves to a [`SubscriptionStream`] of decoded items, or to an [`Error`] if the
/// subscription could not be started.
async fn subscribe(
&self,
request_body: RequestBody,
) -> Result<SubscriptionStream<T>, Error>;
}
/// Client that sends requests over HTTP.
///
/// Holds the target endpoint and a `reqwest::Client`; each request is sent as a JSON `POST`
/// to that endpoint.
pub struct HttpClient {
// URL every request is posted to.
http_endpoint: Url,
// Underlying reqwest client, built once in `HttpClient::new` with its default headers.
http: reqwest::Client,
}
/// Client that runs subscriptions over a WebSocket connection.
pub struct WsClient {
// The connection, behind an async mutex so it can be used through `&self`.
ws: Mutex<GraphQLWebSocket>,
}
/// Errors reported by the clients in this file.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// The response had no `data` but did carry an `errors` list; holds those errors.
#[error("The response contains errors.")]
GraphQL(Vec<GraphQLError>),
/// An error raised by `reqwest` while sending the request or reading the response.
#[error("An HTTP error occurred.")]
Http(#[from] reqwest::Error),
/// A `serde_json` error raised while decoding a payload.
#[error("An error parsing JSON response occurred.")]
Json(#[from] serde_json::Error),
/// A raw JSON error payload sent by the server.
///
/// NOTE(review): this variant is not constructed anywhere in this file; presumably it is
/// produced by the `ws` module — confirm.
#[error("The server replied with an error payload.")]
Server(serde_json::Value),
/// An error raised by `tungstenite` on the WebSocket connection.
#[error("A WebSocket error occurred.")]
WebSocket(#[from] tungstenite::Error),
/// The response carried neither `data` nor `errors`.
#[error("The response body is empty.")]
Empty,
}
impl WsClient {
pub async fn new(
endpoint: &Url,
bearer_token: Option<String>,
ws_protocols: Vec<String>,
) -> Result<Self, tungstenite::Error> {
let mut req = Request::builder().uri(endpoint.as_str());
if let Some(bearer_token) = bearer_token {
req = req.header("Authorization", format!("Bearer {}", bearer_token));
}
if !ws_protocols.is_empty() {
req = req.header("Sec-WebSocket-Protocol", ws_protocols.join(", "))
}
let ws = GraphQLWebSocket::connect(req.body(()).unwrap()).await?;
Ok(Self { ws: Mutex::new(ws) })
}
async fn sub_inner<T>(&self, request_body: RequestBody) -> Result<SubscriptionStream<T>, Error>
where
T: for<'de> Deserialize<'de> + Unpin + Send + 'static,
{
let (tx, rx) = futures::channel::mpsc::unbounded();
let subscription = {
let mut ws = self.ws.lock().await;
ws.subscribe::<T>(request_body).await?
};
tokio::spawn(async move {
let mut tx = tx;
let mut stream = subscription.stream();
while let Some(msg) = stream.next().await {
match msg {
Ok(value) => match serde_json::from_value(value) {
Ok(v) => tx.send(Ok(v)).await.unwrap_or(()),
Err(e) => tx.send(Err(Error::Json(e))).await.unwrap_or(()),
},
Err(err) => tx.send(Err(err)).await.unwrap_or(()),
}
}
});
Ok(Box::pin(rx))
}
}
/// [`Subscriber`] implementation backed by the WebSocket connection.
#[async_trait]
impl<T> Subscriber<T> for WsClient
where
T: for<'de> Deserialize<'de> + Unpin + Send + 'static,
{
/// Starts the subscription by delegating to the inherent `sub_inner` method.
async fn subscribe(&self, request_body: RequestBody) -> Result<SubscriptionStream<T>, Error> {
self.sub_inner(request_body).await
}
}
/// [`Executor`] implementation that sends the request over HTTP.
#[async_trait]
impl<'a, T> Executor<'a, T> for HttpClient
where
T: for<'de> Deserialize<'de> + 'a,
{
/// Executes the request by delegating to the inherent `execute_inner` method.
async fn execute(&'a self, request_body: RequestBody) -> Result<T, Error> {
self.execute_inner(request_body).await
}
}
impl HttpClient {
    /// Builds a client that sends every request to `endpoint`.
    ///
    /// When `bearer_token` is provided it is installed as a default
    /// `Authorization: Bearer <token>` header. The header value is marked sensitive so that it
    /// is not printed when headers are debug-formatted.
    ///
    /// # Panics
    ///
    /// Panics if the token contains characters that are not allowed in an HTTP header value,
    /// or if the underlying `reqwest::Client` cannot be constructed.
    pub fn new(endpoint: &Url, bearer_token: Option<String>) -> Self {
        let mut header_map = HeaderMap::new();
        if let Some(token) = bearer_token {
            let mut value = HeaderValue::from_str(&format!("Bearer {}", token))
                .expect("bearer token must only contain valid HTTP header value characters");
            // Credentials should never leak through `Debug` output of the header map.
            value.set_sensitive(true);
            header_map.insert(AUTHORIZATION, value);
        }
        let http = reqwest::Client::builder()
            .default_headers(header_map)
            .build()
            .expect("failed to initialise the HTTP client");
        Self {
            http_endpoint: endpoint.clone(),
            http,
        }
    }

    /// Posts `request_body` as JSON to the endpoint and decodes the reply.
    ///
    /// The reply is parsed as a [`Response<T>`] and converted with the file's
    /// `From<Response<T>> for Result<T, Error>` implementation: present `data` wins (any
    /// accompanying `errors` are discarded), `errors` without `data` become
    /// [`Error::GraphQL`], and a reply with neither becomes [`Error::Empty`].
    ///
    /// # Errors
    ///
    /// Besides the two cases above, returns [`Error::Http`] when sending the request or
    /// reading/decoding the response body fails.
    async fn execute_inner<T>(&self, request_body: RequestBody) -> Result<T, Error>
    where
        T: for<'de> Deserialize<'de>,
    {
        let response = self
            .http
            .post(self.http_endpoint.clone())
            .json(&request_body)
            .send()
            .await?;
        let body: Response<T> = response.json().await?;
        // Single source of truth for the data/errors precedence rules.
        body.into()
    }
}
/// The body of a request: query text, operation name and variables.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestBody {
/// Variables for the operation, as an arbitrary JSON value.
pub variables: serde_json::Value,
/// The query document text.
pub query: &'static str,
/// Name of the operation; (de)serialized under the key `operationName`.
#[serde(rename = "operationName")]
pub operation_name: &'static str,
}
/// A line/column pair attached to a [`GraphQLError`].
///
/// The `Default` value is line 0, column 0.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq)]
pub struct Location {
/// Line number.
pub line: i32,
/// Column number.
pub column: i32,
}
/// One element of the `path` of a [`GraphQLError`].
///
/// The enum is untagged: a JSON string maps to [`PathFragment::Key`] and a JSON integer to
/// [`PathFragment::Index`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum PathFragment {
/// A string path element.
Key(String),
/// An integer path element.
Index(i32),
}
impl Display for PathFragment {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
PathFragment::Key(ref key) => write!(f, "{}", key),
PathFragment::Index(ref idx) => write!(f, "{}", idx),
}
}
}
/// A single entry of the `errors` list of a [`Response`].
///
/// Only `message` is required when deserializing; the other fields are `None` when absent.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GraphQLError {
/// The error message.
pub message: String,
/// Locations associated with the error, if any were given.
pub locations: Option<Vec<Location>>,
/// Path associated with the error, as a sequence of keys and indices, if given.
pub path: Option<Vec<PathFragment>>,
/// Additional entries keyed by name, as arbitrary JSON values, if given.
pub extensions: Option<HashMap<String, serde_json::Value>>,
}
impl Display for GraphQLError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let path = self
.path
.as_ref()
.map(|fragments| {
fragments
.iter()
.fold(String::new(), |mut acc, item| {
acc.push_str(&format!("{}/", item));
acc
})
.trim_end_matches('/')
.to_string()
})
.unwrap_or_else(|| "<query>".to_string());
let loc = self
.locations
.as_ref()
.and_then(|locations| locations.iter().next())
.cloned()
.unwrap_or_else(Location::default);
write!(f, "{}:{}:{}: {}", path, loc.line, loc.column, self.message)
}
}
/// A response envelope: optional `data` of type `Data` plus an optional list of errors.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct Response<Data> {
/// The decoded payload, if the response contained one.
pub data: Option<Data>,
/// The errors reported in the response, if any.
pub errors: Option<Vec<GraphQLError>>,
}
impl<Data> From<Response<Data>> for Result<Data, Error> {
fn from(res: Response<Data>) -> Self {
match (res.data, res.errors) {
(Some(data), _) => Ok(data),
(None, Some(errs)) => Err(Error::GraphQL(errs)),
(None, None) => Err(Error::Empty),
}
}
}
impl<Data> Clone for Response<Data>
where
Data: Clone,
{
fn clone(&self) -> Self {
Self {
data: self.data.clone(),
errors: self.errors.clone(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Deserializes `raw` into a [`GraphQLError`], panicking if that fails.
    fn parse(raw: serde_json::Value) -> GraphQLError {
        serde_json::from_value(raw).unwrap()
    }

    /// The two locations used by the "full" fixtures.
    fn expected_locations() -> Vec<Location> {
        [(3, 13), (56, 1)]
            .iter()
            .map(|&(line, column)| Location { line, column })
            .collect()
    }

    /// The mixed key/index path used by the "full" fixtures.
    fn expected_path() -> Vec<PathFragment> {
        vec![
            PathFragment::Key("home".to_owned()),
            PathFragment::Key("alone".to_owned()),
            PathFragment::Index(3),
            PathFragment::Key("rating".to_owned()),
        ]
    }

    /// An error carrying only `message` leaves every optional field as `None`.
    #[test]
    fn graphql_error_works_with_just_message() {
        let parsed = parse(json!({
            "message": "I accidentally your whole query"
        }));
        let expected = GraphQLError {
            message: "I accidentally your whole query".to_string(),
            locations: None,
            path: None,
            extensions: None,
        };
        assert_eq!(parsed, expected);
    }

    /// Locations and a mixed key/index path are decoded; extensions stay `None`.
    #[test]
    fn full_graphql_error_deserialization() {
        let parsed = parse(json!({
            "message": "I accidentally your whole query",
            "locations": [{ "line": 3, "column": 13}, {"line": 56, "column": 1}],
            "path": ["home", "alone", 3, "rating"]
        }));
        let expected = GraphQLError {
            message: "I accidentally your whole query".to_string(),
            locations: Some(expected_locations()),
            path: Some(expected_path()),
            extensions: None,
        };
        assert_eq!(parsed, expected);
    }

    /// The `extensions` object is decoded into a map of raw JSON values.
    #[test]
    fn full_graphql_error_with_extensions_deserialization() {
        let parsed = parse(json!({
            "message": "I accidentally your whole query",
            "locations": [{ "line": 3, "column": 13}, {"line": 56, "column": 1}],
            "path": ["home", "alone", 3, "rating"],
            "extensions": {
                "code": "CAN_NOT_FETCH_BY_ID",
                "timestamp": "Fri Feb 9 14:33:09 UTC 2018"
            }
        }));
        let extensions: HashMap<String, serde_json::Value> = vec![
            ("code", json!("CAN_NOT_FETCH_BY_ID")),
            ("timestamp", json!("Fri Feb 9 14:33:09 UTC 2018")),
        ]
        .into_iter()
        .map(|(key, value)| (key.to_owned(), value))
        .collect();
        let expected = GraphQLError {
            message: "I accidentally your whole query".to_string(),
            locations: Some(expected_locations()),
            path: Some(expected_path()),
            extensions: Some(extensions),
        };
        assert_eq!(parsed, expected);
    }
}