async_graphql_actix_web/
request.rs1use std::{
2 future::Future,
3 io::{self, ErrorKind},
4 pin::Pin,
5 str::FromStr,
6};
7
8use actix_http::{
9 body::BoxBody,
10 error::PayloadError,
11 header::{HeaderName, HeaderValue},
12};
13use actix_web::{
14 Error, FromRequest, HttpRequest, HttpResponse, Responder, Result,
15 dev::Payload,
16 error::JsonPayloadError,
17 http,
18 http::{Method, StatusCode},
19};
20use async_graphql::{ParseRequestError, http::MultipartOptions};
21use futures_util::{
22 StreamExt, TryStreamExt,
23 future::{self, FutureExt},
24};
25
26pub struct GraphQLRequest(pub async_graphql::Request);
31
32impl GraphQLRequest {
33 #[must_use]
35 pub fn into_inner(self) -> async_graphql::Request {
36 self.0
37 }
38}
39
40type BatchToRequestMapper =
41 fn(<<GraphQLBatchRequest as FromRequest>::Future as Future>::Output) -> Result<GraphQLRequest>;
42
43impl FromRequest for GraphQLRequest {
44 type Error = Error;
45 type Future = future::Map<<GraphQLBatchRequest as FromRequest>::Future, BatchToRequestMapper>;
46
47 fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
48 GraphQLBatchRequest::from_request(req, payload).map(|res| {
49 Ok(Self(
50 res?.0
51 .into_single()
52 .map_err(actix_web::error::ErrorBadRequest)?,
53 ))
54 })
55 }
56}
57
58pub struct GraphQLBatchRequest(pub async_graphql::BatchRequest);
63
64impl GraphQLBatchRequest {
65 #[must_use]
67 pub fn into_inner(self) -> async_graphql::BatchRequest {
68 self.0
69 }
70}
71
72impl FromRequest for GraphQLBatchRequest {
73 type Error = Error;
74 type Future = Pin<Box<dyn Future<Output = Result<GraphQLBatchRequest>>>>;
75
76 fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
77 let config = req
78 .app_data::<MultipartOptions>()
79 .cloned()
80 .unwrap_or_default();
81
82 if req.method() == Method::GET {
83 let res = async_graphql::http::parse_query_string(req.query_string())
84 .map_err(io::Error::other);
85 Box::pin(async move { Ok(Self(async_graphql::BatchRequest::Single(res?))) })
86 } else if req.method() == Method::POST {
87 let content_type = req
88 .headers()
89 .get(http::header::CONTENT_TYPE)
90 .and_then(|value| value.to_str().ok())
91 .map(|value| value.to_string());
92
93 let (tx, rx) = async_channel::bounded(16);
94
95 let mut payload = payload.take();
97 actix::spawn(async move {
98 while let Some(item) = payload.next().await {
99 if tx.send(item).await.is_err() {
100 return;
101 }
102 }
103 });
104
105 Box::pin(async move {
106 Ok(GraphQLBatchRequest(
107 async_graphql::http::receive_batch_body(
108 content_type,
109 rx.map_err(|e| match e {
110 PayloadError::Incomplete(Some(e)) | PayloadError::Io(e) => e,
111 PayloadError::Incomplete(None) => {
112 io::Error::from(ErrorKind::UnexpectedEof)
113 }
114 PayloadError::EncodingCorrupted => io::Error::new(
115 ErrorKind::InvalidData,
116 "cannot decode content-encoding",
117 ),
118 PayloadError::Overflow => io::Error::new(
119 ErrorKind::InvalidData,
120 "a payload reached size limit",
121 ),
122 PayloadError::UnknownLength => {
123 io::Error::other("a payload length is unknown")
124 }
125 #[cfg(feature = "http2")]
126 PayloadError::Http2Payload(e) if e.is_io() => e.into_io().unwrap(),
127 #[cfg(feature = "http2")]
128 PayloadError::Http2Payload(e) => io::Error::other(e),
129 _ => io::Error::other(e),
130 })
131 .into_async_read(),
132 config,
133 )
134 .await
135 .map_err(|err| match err {
136 ParseRequestError::PayloadTooLarge => {
137 actix_web::error::ErrorPayloadTooLarge(err)
138 }
139 _ => actix_web::error::ErrorBadRequest(err),
140 })?,
141 ))
142 })
143 } else {
144 Box::pin(async move {
145 Err(actix_web::error::ErrorMethodNotAllowed(
146 "GraphQL only supports GET and POST requests",
147 ))
148 })
149 }
150 }
151}
152
153pub struct GraphQLResponse(pub async_graphql::BatchResponse);
158
159impl From<async_graphql::Response> for GraphQLResponse {
160 fn from(resp: async_graphql::Response) -> Self {
161 Self(resp.into())
162 }
163}
164
165impl From<async_graphql::BatchResponse> for GraphQLResponse {
166 fn from(resp: async_graphql::BatchResponse) -> Self {
167 Self(resp)
168 }
169}
170
#[cfg(feature = "cbor")]
mod cbor {
    use core::fmt::{self, Display, Formatter};

    use actix_web::{ResponseError, http::StatusCode};

    /// Newtype around [`serde_cbor::Error`] so that a CBOR serialization
    /// failure can be turned into an actix-web error response.
    #[derive(Debug)]
    pub struct Error(pub serde_cbor::Error);

    impl Display for Error {
        /// Prints the wrapped error's own `Display` output.
        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
            let Self(inner) = self;
            f.write_fmt(format_args!("{}", inner))
        }
    }

    impl ResponseError for Error {
        /// A serialization failure is always reported as
        /// `500 Internal Server Error`.
        fn status_code(&self) -> StatusCode {
            StatusCode::INTERNAL_SERVER_ERROR
        }
    }
}
190
191impl Responder for GraphQLResponse {
192 type Body = BoxBody;
193
194 fn respond_to(self, req: &HttpRequest) -> HttpResponse {
195 let mut builder = HttpResponse::build(StatusCode::OK);
196
197 if self.0.is_ok() {
198 if let Some(cache_control) = self.0.cache_control().value() {
199 builder.append_header((http::header::CACHE_CONTROL, cache_control));
200 }
201 }
202
203 let accept = req
204 .headers()
205 .get(http::header::ACCEPT)
206 .and_then(|val| val.to_str().ok());
207 let (ct, body) = match accept {
208 #[cfg(feature = "cbor")]
210 Some(ct @ "application/cbor") => (
212 ct,
213 match serde_cbor::to_vec(&self.0) {
214 Ok(body) => body,
215 Err(e) => return HttpResponse::from_error(cbor::Error(e)),
216 },
217 ),
218 _ => (
219 "application/graphql-response+json",
220 match serde_json::to_vec(&self.0) {
221 Ok(body) => body,
222 Err(e) => return HttpResponse::from_error(JsonPayloadError::Serialize(e)),
223 },
224 ),
225 };
226
227 let mut resp = builder.content_type(ct).body(body);
228
229 for (name, value) in self.0.http_headers_iter() {
230 if let (Ok(name), Ok(value)) = (
231 HeaderName::from_str(name.as_str()),
232 HeaderValue::from_bytes(value.as_bytes()),
233 ) {
234 resp.headers_mut().append(name, value);
235 }
236 }
237
238 resp
239 }
240}