async_graphql_rocket/
lib.rs1#![warn(missing_docs)]
13#![forbid(unsafe_code)]
14#![allow(clippy::blocks_in_conditions)]
15
16use core::any::Any;
17use std::io::Cursor;
18
19use async_graphql::{Executor, ParseRequestError, http::MultipartOptions};
20use rocket::{
21    data::{self, Data, FromData, ToByteUnit},
22    form::FromForm,
23    http::{ContentType, Header, Status},
24    response::{self, Responder},
25};
26use tokio_util::compat::TokioAsyncReadCompatExt;
27
28#[derive(Debug)]
39pub struct GraphQLBatchRequest(pub async_graphql::BatchRequest);
40
41impl GraphQLBatchRequest {
42    pub async fn execute<E>(self, executor: &E) -> GraphQLResponse
44    where
45        E: Executor,
46    {
47        GraphQLResponse(executor.execute_batch(self.0).await)
48    }
49}
50
51#[rocket::async_trait]
52impl<'r> FromData<'r> for GraphQLBatchRequest {
53    type Error = ParseRequestError;
54
55    async fn from_data(req: &'r rocket::Request<'_>, data: Data<'r>) -> data::Outcome<'r, Self> {
56        let opts: MultipartOptions = req.rocket().state().copied().unwrap_or_default();
57
58        let request = async_graphql::http::receive_batch_body(
59            req.headers().get_one("Content-Type"),
60            data.open(
61                req.limits()
62                    .get("graphql")
63                    .unwrap_or_else(|| 128.kibibytes()),
64            )
65            .compat(),
66            opts,
67        )
68        .await;
69
70        match request {
71            Ok(request) => data::Outcome::Success(Self(request)),
72            Err(e) => data::Outcome::Error((
73                match e {
74                    ParseRequestError::PayloadTooLarge => Status::PayloadTooLarge,
75                    _ => Status::BadRequest,
76                },
77                e,
78            )),
79        }
80    }
81}
82
83#[derive(Debug)]
94pub struct GraphQLRequest(pub async_graphql::Request);
95
96impl GraphQLRequest {
97    pub async fn execute<E>(self, executor: &E) -> GraphQLResponse
99    where
100        E: Executor,
101    {
102        GraphQLResponse(executor.execute(self.0).await.into())
103    }
104
105    #[must_use]
107    pub fn data<D: Any + Send + Sync>(mut self, data: D) -> Self {
108        self.0.data.insert(data);
109        self
110    }
111}
112
113impl From<GraphQLQuery> for GraphQLRequest {
114    fn from(query: GraphQLQuery) -> Self {
115        let mut request = async_graphql::Request::new(query.query);
116
117        if let Some(operation_name) = query.operation_name {
118            request = request.operation_name(operation_name);
119        }
120
121        if let Some(variables) = query.variables {
122            let value = serde_json::from_str(&variables).unwrap_or_default();
123            let variables = async_graphql::Variables::from_json(value);
124            request = request.variables(variables);
125        }
126
127        GraphQLRequest(request)
128    }
129}
130
131#[derive(FromForm, Debug)]
142pub struct GraphQLQuery {
143    query: String,
144    #[field(name = "operationName")]
145    operation_name: Option<String>,
146    variables: Option<String>,
147}
148
149impl GraphQLQuery {
150    pub async fn execute<E>(self, executor: &E) -> GraphQLResponse
152    where
153        E: Executor,
154    {
155        let request: GraphQLRequest = self.into();
156        request.execute(executor).await
157    }
158}
159
160#[rocket::async_trait]
161impl<'r> FromData<'r> for GraphQLRequest {
162    type Error = ParseRequestError;
163
164    async fn from_data(req: &'r rocket::Request<'_>, data: Data<'r>) -> data::Outcome<'r, Self> {
165        GraphQLBatchRequest::from_data(req, data)
166            .await
167            .and_then(|request| match request.0.into_single() {
168                Ok(single) => data::Outcome::Success(Self(single)),
169                Err(e) => data::Outcome::Error((Status::BadRequest, e)),
170            })
171    }
172}
173
174#[derive(Debug)]
180pub struct GraphQLResponse(pub async_graphql::BatchResponse);
181
182impl From<async_graphql::BatchResponse> for GraphQLResponse {
183    fn from(batch: async_graphql::BatchResponse) -> Self {
184        Self(batch)
185    }
186}
187impl From<async_graphql::Response> for GraphQLResponse {
188    fn from(res: async_graphql::Response) -> Self {
189        Self(res.into())
190    }
191}
192
193impl<'r> Responder<'r, 'static> for GraphQLResponse {
194    fn respond_to(self, _: &'r rocket::Request<'_>) -> response::Result<'static> {
195        let body = serde_json::to_string(&self.0).unwrap();
196
197        let mut response = rocket::Response::new();
198        response.set_header(ContentType::new("application", "json"));
199
200        if self.0.is_ok() {
201            if let Some(cache_control) = self.0.cache_control().value() {
202                response.set_header(Header::new("cache-control", cache_control));
203            }
204        }
205
206        for (name, value) in self.0.http_headers_iter() {
207            if let Ok(value) = value.to_str() {
208                response.adjoin_header(Header::new(name.as_str().to_string(), value.to_string()));
209            }
210        }
211
212        response.set_sized_body(body.len(), Cursor::new(body));
213
214        Ok(response)
215    }
216}