// chuchi/graphql/mod.rs

1/// This is unstable
2mod graphiql;
3
4use crate::error::ClientErrorKind;
5use crate::header::{self, Method, Mime, RequestHeader, StatusCode};
6use crate::routes::{ParamsNames, PathParams, Route, RoutePath};
7use crate::util::PinnedFuture;
8use crate::{Body, Error, Request, Resources, Response};
9
10use std::any::{Any, TypeId};
11
12use juniper::http::{GraphQLBatchRequest, GraphQLRequest};
13use juniper::{
14	GraphQLSubscriptionType, GraphQLType, GraphQLTypeAsync, RootNode,
15	ScalarValue,
16};
17
/// Route serving the GraphiQL in-browser IDE page.
///
/// `uri` is the path prefix the page is served under. `graphql_uri` is handed
/// to the generated GraphiQL page (presumably as the endpoint it sends its
/// queries to).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct GraphiQl {
	uri: &'static str,
	graphql_uri: &'static str,
}

impl GraphiQl {
	/// Creates a GraphiQL route served under `uri` whose page is generated
	/// for the GraphQL endpoint `graphql_uri`.
	pub const fn new(uri: &'static str, graphql_uri: &'static str) -> Self {
		GraphiQl { uri, graphql_uri }
	}
}
29
30impl Route for GraphiQl {
31	fn validate_requirements(&self, _params: &ParamsNames, _data: &Resources) {}
32
33	fn path(&self) -> RoutePath {
34		RoutePath {
35			method: Some(Method::GET),
36			path: format!("{}/{{*?rem}}", self.uri.trim_end_matches('/'))
37				.into(),
38		}
39	}
40
41	fn call<'a>(
42		&'a self,
43		_req: &'a mut Request,
44		_params: &'a PathParams,
45		_: &'a Resources,
46	) -> PinnedFuture<'a, crate::Result<Response>> {
47		PinnedFuture::new(async move {
48			Ok(Response::html(graphiql::graphiql_source(self.graphql_uri)))
49		})
50	}
51}
52
53pub struct GraphQlContext {
54	data: Resources,
55	request_header: RequestHeader,
56}
57
58impl GraphQlContext {
59	// Gets data or RequestHeader
60	pub fn get<D>(&self) -> Option<&D>
61	where
62		D: Any,
63	{
64		if TypeId::of::<D>() == TypeId::of::<RequestHeader>() {
65			<dyn Any>::downcast_ref(&self.request_header)
66		} else {
67			self.data.get()
68		}
69	}
70}
71
72impl juniper::Context for GraphQlContext {}
73
74/// This only supports POST requests
75pub struct GraphQl<Q, M, Sub, S>
76where
77	Q: GraphQLType<S, Context = GraphQlContext>,
78	M: GraphQLType<S, Context = GraphQlContext>,
79	Sub: GraphQLType<S, Context = GraphQlContext>,
80	S: ScalarValue,
81{
82	uri: &'static str,
83	root_node: RootNode<'static, Q, M, Sub, S>,
84}
85
86impl<Q, M, Sub, S> GraphQl<Q, M, Sub, S>
87where
88	Q: GraphQLType<S, Context = GraphQlContext>,
89	M: GraphQLType<S, Context = GraphQlContext>,
90	Sub: GraphQLType<S, Context = GraphQlContext>,
91	S: ScalarValue,
92{
93	pub fn new(
94		uri: &'static str,
95		root_node: RootNode<'static, Q, M, Sub, S>,
96	) -> Self {
97		Self { uri, root_node }
98	}
99}
100
101impl<Q, M, Sub, S> Route for GraphQl<Q, M, Sub, S>
102where
103	Q: GraphQLTypeAsync<S, Context = GraphQlContext> + Send,
104	Q::TypeInfo: Send + Sync,
105	M: GraphQLTypeAsync<S, Context = GraphQlContext> + Send,
106	M::TypeInfo: Send + Sync,
107	Sub: GraphQLSubscriptionType<S, Context = GraphQlContext> + Send,
108	Sub::TypeInfo: Send + Sync,
109	S: ScalarValue + Send + Sync,
110{
111	fn validate_requirements(&self, _params: &ParamsNames, _data: &Resources) {}
112
113	fn path(&self) -> RoutePath {
114		RoutePath {
115			method: Some(Method::POST),
116			path: self.uri.into(),
117		}
118	}
119
120	fn call<'a>(
121		&'a self,
122		req: &'a mut Request,
123		_params: &'a PathParams,
124		data: &'a Resources,
125	) -> PinnedFuture<'a, crate::Result<Response>> {
126		PinnedFuture::new(async move {
127			// get content-type of request
128			let content_type =
129				req.header().value(header::CONTENT_TYPE).unwrap_or("");
130
131			let gql_req: GraphQLBatchRequest<S> = match content_type {
132				"application/json" => {
133					// read json
134					req.deserialize().await?
135				}
136				"application/graphql" => {
137					let body = req
138						.body
139						.take()
140						.into_string()
141						.await
142						.map_err(Error::from_client_io)?;
143
144					GraphQLBatchRequest::Single(GraphQLRequest::new(
145						body, None, None,
146					))
147				}
148				_ => return Err(ClientErrorKind::BadRequest.into()),
149			};
150
151			let ctx = GraphQlContext {
152				data: data.clone(),
153				request_header: req.header().clone(),
154			};
155			let res = gql_req.execute(&self.root_node, &ctx).await;
156
157			let mut resp = Response::builder().content_type(Mime::JSON);
158
159			if !res.is_ok() {
160				resp = resp.status_code(StatusCode::BAD_REQUEST);
161			}
162
163			Ok(resp.body(Body::serialize(&res).unwrap()).build())
164		})
165	}
166}