async_graphql_rocket/
lib.rs1#![warn(missing_docs)]
13#![forbid(unsafe_code)]
14#![allow(clippy::blocks_in_conditions)]
15
16use core::any::Any;
17use std::io::Cursor;
18
19use async_graphql::{Executor, ParseRequestError, http::MultipartOptions};
20use rocket::{
21 data::{self, Data, FromData, ToByteUnit},
22 form::FromForm,
23 http::{ContentType, Header, Status},
24 response::{self, Responder},
25};
26use tokio_util::compat::TokioAsyncReadCompatExt;
27
28#[derive(Debug)]
39pub struct GraphQLBatchRequest(pub async_graphql::BatchRequest);
40
41impl GraphQLBatchRequest {
42 pub async fn execute<E>(self, executor: &E) -> GraphQLResponse
44 where
45 E: Executor,
46 {
47 GraphQLResponse(executor.execute_batch(self.0).await)
48 }
49}
50
51#[rocket::async_trait]
52impl<'r> FromData<'r> for GraphQLBatchRequest {
53 type Error = ParseRequestError;
54
55 async fn from_data(req: &'r rocket::Request<'_>, data: Data<'r>) -> data::Outcome<'r, Self> {
56 let opts: MultipartOptions = req.rocket().state().copied().unwrap_or_default();
57
58 let request = async_graphql::http::receive_batch_body(
59 req.headers().get_one("Content-Type"),
60 data.open(
61 req.limits()
62 .get("graphql")
63 .unwrap_or_else(|| 128.kibibytes()),
64 )
65 .compat(),
66 opts,
67 )
68 .await;
69
70 match request {
71 Ok(request) => data::Outcome::Success(Self(request)),
72 Err(e) => data::Outcome::Error((
73 match e {
74 ParseRequestError::PayloadTooLarge => Status::PayloadTooLarge,
75 _ => Status::BadRequest,
76 },
77 e,
78 )),
79 }
80 }
81}
82
83#[derive(Debug)]
94pub struct GraphQLRequest(pub async_graphql::Request);
95
96impl GraphQLRequest {
97 pub async fn execute<E>(self, executor: &E) -> GraphQLResponse
99 where
100 E: Executor,
101 {
102 GraphQLResponse(executor.execute(self.0).await.into())
103 }
104
105 #[must_use]
107 pub fn data<D: Any + Send + Sync>(mut self, data: D) -> Self {
108 self.0.data.insert(data);
109 self
110 }
111}
112
113impl From<GraphQLQuery> for GraphQLRequest {
114 fn from(query: GraphQLQuery) -> Self {
115 let mut request = async_graphql::Request::new(query.query);
116
117 if let Some(operation_name) = query.operation_name {
118 request = request.operation_name(operation_name);
119 }
120
121 if let Some(variables) = query.variables {
122 let value = serde_json::from_str(&variables).unwrap_or_default();
123 let variables = async_graphql::Variables::from_json(value);
124 request = request.variables(variables);
125 }
126
127 GraphQLRequest(request)
128 }
129}
130
131#[derive(FromForm, Debug)]
142pub struct GraphQLQuery {
143 query: String,
144 #[field(name = "operationName")]
145 operation_name: Option<String>,
146 variables: Option<String>,
147}
148
149impl GraphQLQuery {
150 pub async fn execute<E>(self, executor: &E) -> GraphQLResponse
152 where
153 E: Executor,
154 {
155 let request: GraphQLRequest = self.into();
156 request.execute(executor).await
157 }
158}
159
160#[rocket::async_trait]
161impl<'r> FromData<'r> for GraphQLRequest {
162 type Error = ParseRequestError;
163
164 async fn from_data(req: &'r rocket::Request<'_>, data: Data<'r>) -> data::Outcome<'r, Self> {
165 GraphQLBatchRequest::from_data(req, data)
166 .await
167 .and_then(|request| match request.0.into_single() {
168 Ok(single) => data::Outcome::Success(Self(single)),
169 Err(e) => data::Outcome::Error((Status::BadRequest, e)),
170 })
171 }
172}
173
174#[derive(Debug)]
180pub struct GraphQLResponse(pub async_graphql::BatchResponse);
181
182impl From<async_graphql::BatchResponse> for GraphQLResponse {
183 fn from(batch: async_graphql::BatchResponse) -> Self {
184 Self(batch)
185 }
186}
187impl From<async_graphql::Response> for GraphQLResponse {
188 fn from(res: async_graphql::Response) -> Self {
189 Self(res.into())
190 }
191}
192
193impl<'r> Responder<'r, 'static> for GraphQLResponse {
194 fn respond_to(self, _: &'r rocket::Request<'_>) -> response::Result<'static> {
195 let body = serde_json::to_string(&self.0).unwrap();
196
197 let mut response = rocket::Response::new();
198 response.set_header(ContentType::new("application", "json"));
199
200 if self.0.is_ok() {
201 if let Some(cache_control) = self.0.cache_control().value() {
202 response.set_header(Header::new("cache-control", cache_control));
203 }
204 }
205
206 for (name, value) in self.0.http_headers_iter() {
207 if let Ok(value) = value.to_str() {
208 response.adjoin_header(Header::new(name.as_str().to_string(), value.to_string()));
209 }
210 }
211
212 response.set_sized_body(body.len(), Cursor::new(body));
213
214 Ok(response)
215 }
216}