mod graphiql;
use crate::{Request, Response, Error, Data, Body};
use crate::header::{self, RequestHeader, Method, StatusCode, Mime};
use crate::routes::Route;
use crate::util::PinnedFuture;
use crate::error::{ClientErrorKind};
use std::any::{Any, TypeId};
use juniper::{
RootNode, GraphQLType, GraphQLTypeAsync, GraphQLSubscriptionType,
ScalarValue
};
use juniper::http::{GraphQLRequest, GraphQLBatchRequest};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GraphiQl {
uri: &'static str,
graphql_uri: &'static str
}
impl GraphiQl {
pub const fn new(uri: &'static str, graphql_uri: &'static str) -> Self {
Self { uri, graphql_uri }
}
}
impl Route for GraphiQl {
fn check(&self, header: &RequestHeader) -> bool {
header.method() == &Method::GET &&
header.uri().path().starts_with(self.uri)
}
fn validate_data(&self, _data: &Data) {}
fn call<'a>(
&'a self,
_req: &'a mut Request,
_: &'a Data
) -> PinnedFuture<'a, crate::Result<Response>> {
PinnedFuture::new(async move {
Ok(Response::html(
graphiql::graphiql_source(self.graphql_uri)
))
})
}
}
pub struct GraphQlContext {
data: Data,
request_header: RequestHeader
}
impl GraphQlContext {
pub fn get<D>(&self) -> Option<&D>
where D: Any {
if TypeId::of::<D>() == TypeId::of::<RequestHeader>() {
<dyn Any>::downcast_ref(&self.request_header)
} else {
self.data.get()
}
}
}
impl juniper::Context for GraphQlContext {}
pub struct GraphQl<Q, M, Sub, S>
where
Q: GraphQLType<S, Context=GraphQlContext>,
M: GraphQLType<S, Context=GraphQlContext>,
Sub: GraphQLType<S, Context=GraphQlContext>,
S: ScalarValue
{
uri: &'static str,
root_node: RootNode<'static, Q, M, Sub, S>
}
impl<Q, M, Sub, S> GraphQl<Q, M, Sub, S>
where
Q: GraphQLType<S, Context=GraphQlContext>,
M: GraphQLType<S, Context=GraphQlContext>,
Sub: GraphQLType<S, Context=GraphQlContext>,
S: ScalarValue
{
pub fn new(
uri: &'static str,
root_node: RootNode<'static, Q, M, Sub, S>
) -> Self {
Self { uri, root_node }
}
}
impl<Q, M, Sub, S> Route for GraphQl<Q, M, Sub, S>
where
Q: GraphQLTypeAsync<S, Context=GraphQlContext> + Send,
Q::TypeInfo: Send + Sync,
M: GraphQLTypeAsync<S, Context=GraphQlContext> + Send,
M::TypeInfo: Send + Sync,
Sub: GraphQLSubscriptionType<S, Context=GraphQlContext> + Send,
Sub::TypeInfo: Send + Sync,
S: ScalarValue + Send + Sync
{
fn check(&self, header: &RequestHeader) -> bool {
header.method() == &Method::POST &&
header.uri().path().starts_with(self.uri)
}
fn validate_data(&self, _data: &Data) {}
fn call<'a>(
&'a self,
req: &'a mut Request,
data: &'a Data
) -> PinnedFuture<'a, crate::Result<Response>> {
PinnedFuture::new(async move {
let content_type = req.header().value(header::CONTENT_TYPE)
.unwrap_or("");
let gql_req: GraphQLBatchRequest<S> = match content_type {
"application/json" => {
req.deserialize().await?
},
"application/graphql" => {
let body = req.body.take().into_string().await
.map_err(Error::from_client_io)?;
GraphQLBatchRequest::Single(
GraphQLRequest::new(body, None, None)
)
},
_ => return Err(ClientErrorKind::BadRequest.into())
};
let ctx = GraphQlContext {
data: data.clone(),
request_header: req.header().clone()
};
let res = gql_req.execute(&self.root_node, &ctx).await;
let mut resp = Response::builder()
.content_type(Mime::JSON);
if !res.is_ok() {
resp = resp.status_code(StatusCode::BAD_REQUEST);
}
Ok(resp.body(Body::serialize(&res).unwrap()).build())
})
}
}