#![deny(missing_docs)]
#![deny(rust_2018_idioms)]
mod ws;
use futures_util::{stream::Stream, StreamExt};
use reqwest::{Url, header::{AUTHORIZATION, HeaderMap, HeaderValue}};
use serde::{Deserialize, Serialize};
use tokio_tungstenite::tungstenite::{self, http::Request};
use ws::GraphQLWebSocket;
use std::pin::Pin;
use std::{collections::HashMap, future::Future};
use std::{
fmt::{self, Display},
sync::Mutex,
};
/// Something that can execute a GraphQL request and produce a value of type `T`.
///
/// `T` must be deserializable from any input lifetime and must outlive `'a`,
/// the lifetime for which the executor is borrowed by the returned future.
pub trait Executor<'a, T>
where
T: for<'de> Deserialize<'de> + 'a,
{
/// Runs the request described by `request_body`.
///
/// Returns a pinned, boxed future (valid for `'a`) that resolves to the
/// deserialized `T`, or to an [`Error`] on failure.
fn execute(
&'a self,
request_body: RequestBody,
) -> Pin<Box<dyn Future<Output = Result<T, Error>> + 'a>>;
}
/// A pinned, boxed, `Send` stream whose items are `Result<T, Error>`; this is
/// the type returned by [`Subscriber::subscribe`].
pub type SubscriptionStream<T> = Pin<Box<dyn Stream<Item = Result<T, Error>> + Send>>;
/// Something that can start a GraphQL subscription yielding values of type `T`.
///
/// `T` must be deserializable from any input lifetime and be
/// `Unpin + Send + 'static`.
pub trait Subscriber<T>
where
T: for<'de> Deserialize<'de> + Unpin + Send + 'static,
{
/// Starts the subscription described by `request_body`.
///
/// Returns a [`SubscriptionStream`] whose items are either a deserialized
/// `T` or an [`Error`].
fn subscribe(
&self,
request_body: RequestBody,
) -> SubscriptionStream<T>;
}
/// A GraphQL client that executes requests over HTTP using `reqwest`.
pub struct HttpClient {
// URL every request is POSTed to.
http_endpoint: Url,
// Underlying `reqwest` client; built in `HttpClient::new` with the default headers.
http: reqwest::Client,
}
/// A GraphQL client that runs subscriptions over a WebSocket connection.
pub struct WsClient {
// The connection is kept behind a mutex: subscribing needs mutable access to
// it while the public API only hands out `&self`.
ws: Mutex<GraphQLWebSocket>,
}
/// Errors produced by the clients in this crate.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// The response carried no data but did carry a list of GraphQL errors.
#[error("The response contains errors.")]
GraphQL(Vec<GraphQLError>),
/// An error reported by `reqwest` (convertible via `From`).
#[error("An HTTP error occurred.")]
Http(#[from] reqwest::Error),
/// An error reported by `serde_json` (convertible via `From`), e.g. when a
/// subscription payload cannot be deserialized into the requested type.
#[error("An error parsing JSON response occurred.")]
Json(#[from] serde_json::Error),
/// An error payload sent by the server, kept as raw JSON.
// NOTE(review): not constructed in this file; presumably produced by the
// `ws` module — confirm.
#[error("The server replied with an error payload.")]
Server(serde_json::Value),
/// An error reported by `tungstenite` (convertible via `From`).
#[error("A WebSocket error occurred.")]
WebSocket(#[from] tungstenite::Error),
/// The response contained neither data nor errors.
#[error("The response body is empty.")]
Empty,
}
impl WsClient {
pub async fn new(endpoint: &Url, bearer_token: Option<String>, ws_protocols: Vec<String>) -> Result<Self, tungstenite::Error> {
let mut req = Request::builder().uri(endpoint.as_str());
if let Some(bearer_token) = bearer_token {
req = req.header("Authorization", format!("Bearer {}", bearer_token));
}
if !ws_protocols.is_empty() {
req = req.header("Sec-WebSocket-Protocol", ws_protocols.join(", "))
}
let ws = GraphQLWebSocket::connect(req.body(()).unwrap()).await?;
Ok(Self { ws: Mutex::new(ws) })
}
fn sub_inner<T>(&self, request_body: RequestBody) -> impl Stream<Item = Result<T, Error>> + Send
where
T: for<'de> Deserialize<'de> + Unpin + Send + 'static,
{
let subscription = {
let mut ws = self.ws.lock().unwrap();
ws.subscribe::<T>(request_body)
};
let mut stream = subscription.stream();
async_stream::stream! {
while let Some(msg) = stream.next().await {
match msg {
Ok(value) => match serde_json::from_value(value) {
Ok(v) => yield Ok(v),
Err(e) => yield Err(Error::Json(e)),
},
Err(err) => yield Err(err),
}
}
}
}
}
/// [`Subscriber`] implementation backed by the client's WebSocket connection.
impl<T> Subscriber<T> for WsClient
where
    T: for<'de> Deserialize<'de> + Unpin + Send + 'static,
{
    /// Starts the subscription described by `request_body` and returns its
    /// items as a boxed [`SubscriptionStream`].
    fn subscribe(&self, request_body: RequestBody) -> SubscriptionStream<T> {
        let items = self.sub_inner::<T>(request_body);
        Box::pin(items)
    }
}
/// [`Executor`] implementation that performs the request over HTTP.
impl<'a, T> Executor<'a, T> for HttpClient
where
    T: for<'de> Deserialize<'de> + 'a,
{
    /// Executes `request_body` against the client's endpoint and returns the
    /// pending work as a boxed future borrowing the client for `'a`.
    fn execute(
        &'a self,
        request_body: RequestBody,
    ) -> Pin<Box<dyn Future<Output = Result<T, Error>> + 'a>> {
        let pending = self.execute_inner::<T>(request_body);
        Box::pin(pending)
    }
}
impl HttpClient {
    /// Creates a client that POSTs GraphQL requests to `endpoint`.
    ///
    /// When `bearer_token` is `Some`, every request carries an
    /// `Authorization: Bearer <token>` default header. The header value is
    /// flagged as sensitive so the secret is treated as such by the HTTP stack.
    ///
    /// # Panics
    ///
    /// Panics if the token contains bytes that are not valid in an HTTP header
    /// value, or if the underlying `reqwest` client cannot be built.
    pub fn new(endpoint: &Url, bearer_token: Option<String>) -> Self {
        let mut header_map = HeaderMap::new();
        if let Some(token) = bearer_token {
            let mut authorization = HeaderValue::from_str(&format!("Bearer {}", token))
                .expect("bearer token must be a valid HTTP header value");
            // The value is a credential; mark it so it is not exposed verbatim.
            authorization.set_sensitive(true);
            header_map.insert(AUTHORIZATION, authorization);
        }
        Self {
            http_endpoint: endpoint.clone(),
            http: reqwest::Client::builder()
                .default_headers(header_map)
                .build()
                .expect("failed to build the reqwest HTTP client"),
        }
    }

    /// POSTs `request_body` as JSON to the endpoint and decodes the reply.
    ///
    /// The reply is parsed as a [`Response<T>`] and collapsed with the crate's
    /// `From<Response<T>>` conversion: data wins when present, otherwise the
    /// GraphQL errors are returned, otherwise [`Error::Empty`].
    ///
    /// # Errors
    ///
    /// Returns [`Error::Http`] when sending the request or decoding the body
    /// fails, [`Error::GraphQL`] when the reply has errors but no data, and
    /// [`Error::Empty`] when it has neither.
    async fn execute_inner<T>(&self, request_body: RequestBody) -> Result<T, Error>
    where
        T: for<'de> Deserialize<'de>,
    {
        let response = self
            .http
            .post(self.http_endpoint.clone())
            .json(&request_body)
            .send()
            .await?;
        let body: Response<T> = response.json().await?;
        // Single source of truth for the data/errors policy.
        body.into()
    }
}
/// The body of a GraphQL request, as serialized to JSON by the clients.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestBody {
/// The variables of the operation, as an arbitrary JSON value.
pub variables: serde_json::Value,
/// The query text.
pub query: &'static str,
/// The name of the operation; serialized under the key `operationName`.
#[serde(rename = "operationName")]
pub operation_name: &'static str,
}
/// A line/column pair, as found in the `locations` list of a [`GraphQLError`].
///
/// The default value is line 0, column 0.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq)]
pub struct Location {
/// The line number.
pub line: i32,
/// The column number.
pub column: i32,
}
/// One element of the `path` list of a [`GraphQLError`].
///
/// The enum is untagged for serde: a JSON string maps to `Key` and a JSON
/// integer maps to `Index`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum PathFragment {
/// A string path element.
Key(String),
/// An integer path element.
Index(i32),
}
impl Display for PathFragment {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
PathFragment::Key(ref key) => write!(f, "{}", key),
PathFragment::Index(ref idx) => write!(f, "{}", idx),
}
}
}
/// A single entry of the `errors` list of a [`Response`].
///
/// Only `message` is required when deserializing; the other fields are optional.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GraphQLError {
/// The error message.
pub message: String,
/// The locations attached to the error, if any.
pub locations: Option<Vec<Location>>,
/// The path attached to the error, if any, as a list of keys and indices.
pub path: Option<Vec<PathFragment>>,
/// Additional entries attached to the error, if any, keyed by name and kept
/// as raw JSON values.
pub extensions: Option<HashMap<String, serde_json::Value>>,
}
impl Display for GraphQLError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let path = self
.path
.as_ref()
.map(|fragments| {
fragments
.iter()
.fold(String::new(), |mut acc, item| {
acc.push_str(&format!("{}/", item));
acc
})
.trim_end_matches('/')
.to_string()
})
.unwrap_or_else(|| "<query>".to_string());
let loc = self
.locations
.as_ref()
.and_then(|locations| locations.iter().next())
.cloned()
.unwrap_or_else(Location::default);
write!(f, "{}:{}:{}: {}", path, loc.line, loc.column, self.message)
}
}
/// The body of a GraphQL response, generic over the type of its data.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct Response<Data> {
/// The data of the response, if any.
pub data: Option<Data>,
/// The errors of the response, if any.
pub errors: Option<Vec<GraphQLError>>,
}
impl<Data> From<Response<Data>> for Result<Data, Error> {
fn from(res: Response<Data>) -> Self {
match (res.data, res.errors) {
(Some(data), _) => Ok(data),
(None, Some(errs)) => Err(Error::GraphQL(errs)),
(None, None) => Err(Error::Empty),
}
}
}
impl<Data> Clone for Response<Data>
where
Data: Clone,
{
fn clone(&self) -> Self {
Self {
data: self.data.clone(),
errors: self.errors.clone(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Message expected in every deserialized error below.
    fn expected_message() -> String {
        "I accidentally your whole query".to_string()
    }

    // Locations expected from the fixtures that carry a `locations` list.
    fn expected_locations() -> Option<Vec<Location>> {
        let first = Location {
            line: 3,
            column: 13,
        };
        let second = Location {
            line: 56,
            column: 1,
        };
        Some(vec![first, second])
    }

    // Path expected from the fixtures that carry a `path` list.
    fn expected_path() -> Option<Vec<PathFragment>> {
        Some(vec![
            PathFragment::Key("home".to_owned()),
            PathFragment::Key("alone".to_owned()),
            PathFragment::Index(3),
            PathFragment::Key("rating".to_owned()),
        ])
    }

    // Deserializes a JSON fixture into a `GraphQLError`, panicking on failure.
    fn parse(fixture: serde_json::Value) -> GraphQLError {
        serde_json::from_value(fixture).unwrap()
    }

    #[test]
    fn graphql_error_works_with_just_message() {
        let fixture = json!({
            "message": "I accidentally your whole query"
        });
        let expected = GraphQLError {
            message: expected_message(),
            locations: None,
            path: None,
            extensions: None,
        };
        assert_eq!(parse(fixture), expected);
    }

    #[test]
    fn full_graphql_error_deserialization() {
        let fixture = json!({
            "message": "I accidentally your whole query",
            "locations": [{ "line": 3, "column": 13}, {"line": 56, "column": 1}],
            "path": ["home", "alone", 3, "rating"]
        });
        let expected = GraphQLError {
            message: expected_message(),
            locations: expected_locations(),
            path: expected_path(),
            extensions: None,
        };
        assert_eq!(parse(fixture), expected);
    }

    #[test]
    fn full_graphql_error_with_extensions_deserialization() {
        let fixture = json!({
            "message": "I accidentally your whole query",
            "locations": [{ "line": 3, "column": 13}, {"line": 56, "column": 1}],
            "path": ["home", "alone", 3, "rating"],
            "extensions": {
                "code": "CAN_NOT_FETCH_BY_ID",
                "timestamp": "Fri Feb 9 14:33:09 UTC 2018"
            }
        });
        let extensions: HashMap<String, serde_json::Value> = [
            ("code", "CAN_NOT_FETCH_BY_ID"),
            ("timestamp", "Fri Feb 9 14:33:09 UTC 2018"),
        ]
        .iter()
        .map(|&(key, value)| (key.to_owned(), json!(value)))
        .collect();
        let expected = GraphQLError {
            message: expected_message(),
            locations: expected_locations(),
            path: expected_path(),
            extensions: Some(extensions),
        };
        assert_eq!(parse(fixture), expected);
    }
}