1mod graphiql;
3
4use crate::error::ClientErrorKind;
5use crate::header::{self, Method, Mime, RequestHeader, StatusCode};
6use crate::routes::{ParamsNames, PathParams, Route, RoutePath};
7use crate::util::PinnedFuture;
8use crate::{Body, Error, Request, Resources, Response};
9
10use std::any::{Any, TypeId};
11
12use juniper::http::{GraphQLBatchRequest, GraphQLRequest};
13use juniper::{
14 GraphQLSubscriptionType, GraphQLType, GraphQLTypeAsync, RootNode,
15 ScalarValue,
16};
17
/// Route serving the GraphiQL in-browser IDE.
///
/// The page is served for `GET` requests below `uri` and is generated for
/// the GraphQL endpoint located at `graphql_uri`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct GraphiQl {
    /// Path prefix the IDE page is served under.
    uri: &'static str,
    /// Path of the GraphQL endpoint handed to the generated IDE page.
    graphql_uri: &'static str,
}

impl GraphiQl {
    /// Creates a GraphiQL route served under `uri` whose page targets the
    /// GraphQL endpoint at `graphql_uri`.
    ///
    /// This is a `const fn`, so the route can be built in a `const`/`static`.
    pub const fn new(uri: &'static str, graphql_uri: &'static str) -> Self {
        GraphiQl { graphql_uri, uri }
    }
}
29
30impl Route for GraphiQl {
31 fn validate_requirements(&self, _params: &ParamsNames, _data: &Resources) {}
32
33 fn path(&self) -> RoutePath {
34 RoutePath {
35 method: Some(Method::GET),
36 path: format!("{}/{{*rem}}", self.uri.trim_end_matches('/')).into(),
37 }
38 }
39
40 fn call<'a>(
41 &'a self,
42 _req: &'a mut Request,
43 _params: &'a PathParams,
44 _: &'a Resources,
45 ) -> PinnedFuture<'a, crate::Result<Response>> {
46 PinnedFuture::new(async move {
47 Ok(Response::html(graphiql::graphiql_source(self.graphql_uri)))
48 })
49 }
50}
51
52pub struct GraphQlContext {
53 data: Resources,
54 request_header: RequestHeader,
55}
56
57impl GraphQlContext {
58 pub fn get<D>(&self) -> Option<&D>
60 where
61 D: Any,
62 {
63 if TypeId::of::<D>() == TypeId::of::<RequestHeader>() {
64 <dyn Any>::downcast_ref(&self.request_header)
65 } else {
66 self.data.get()
67 }
68 }
69}
70
71impl juniper::Context for GraphQlContext {}
72
73pub struct GraphQl<Q, M, Sub, S>
75where
76 Q: GraphQLType<S, Context = GraphQlContext>,
77 M: GraphQLType<S, Context = GraphQlContext>,
78 Sub: GraphQLType<S, Context = GraphQlContext>,
79 S: ScalarValue,
80{
81 uri: &'static str,
82 root_node: RootNode<'static, Q, M, Sub, S>,
83}
84
85impl<Q, M, Sub, S> GraphQl<Q, M, Sub, S>
86where
87 Q: GraphQLType<S, Context = GraphQlContext>,
88 M: GraphQLType<S, Context = GraphQlContext>,
89 Sub: GraphQLType<S, Context = GraphQlContext>,
90 S: ScalarValue,
91{
92 pub fn new(
93 uri: &'static str,
94 root_node: RootNode<'static, Q, M, Sub, S>,
95 ) -> Self {
96 Self { uri, root_node }
97 }
98}
99
100impl<Q, M, Sub, S> Route for GraphQl<Q, M, Sub, S>
101where
102 Q: GraphQLTypeAsync<S, Context = GraphQlContext> + Send,
103 Q::TypeInfo: Send + Sync,
104 M: GraphQLTypeAsync<S, Context = GraphQlContext> + Send,
105 M::TypeInfo: Send + Sync,
106 Sub: GraphQLSubscriptionType<S, Context = GraphQlContext> + Send,
107 Sub::TypeInfo: Send + Sync,
108 S: ScalarValue + Send + Sync,
109{
110 fn validate_requirements(&self, _params: &ParamsNames, _data: &Resources) {}
111
112 fn path(&self) -> RoutePath {
113 RoutePath {
114 method: Some(Method::POST),
115 path: self.uri.into(),
116 }
117 }
118
119 fn call<'a>(
120 &'a self,
121 req: &'a mut Request,
122 _params: &'a PathParams,
123 data: &'a Resources,
124 ) -> PinnedFuture<'a, crate::Result<Response>> {
125 PinnedFuture::new(async move {
126 let content_type =
128 req.header().value(header::CONTENT_TYPE).unwrap_or("");
129
130 let gql_req: GraphQLBatchRequest<S> = match content_type {
131 "application/json" => {
132 req.deserialize().await?
134 }
135 "application/graphql" => {
136 let body = req
137 .body
138 .take()
139 .into_string()
140 .await
141 .map_err(Error::from_client_io)?;
142
143 GraphQLBatchRequest::Single(GraphQLRequest::new(
144 body, None, None,
145 ))
146 }
147 _ => return Err(ClientErrorKind::BadRequest.into()),
148 };
149
150 let ctx = GraphQlContext {
151 data: data.clone(),
152 request_header: req.header().clone(),
153 };
154 let res = gql_req.execute(&self.root_node, &ctx).await;
155
156 let mut resp = Response::builder().content_type(Mime::JSON);
157
158 if !res.is_ok() {
159 resp = resp.status_code(StatusCode::BAD_REQUEST);
160 }
161
162 Ok(resp.body(Body::serialize(&res).unwrap()).build())
163 })
164 }
165}