async_graphql_actix_web/
request.rs

1use std::{
2    future::Future,
3    io::{self, ErrorKind},
4    pin::Pin,
5    str::FromStr,
6};
7
8use actix_http::{
9    body::BoxBody,
10    error::PayloadError,
11    header::{HeaderName, HeaderValue},
12};
13use actix_web::{
14    Error, FromRequest, HttpRequest, HttpResponse, Responder, Result,
15    dev::Payload,
16    error::JsonPayloadError,
17    http,
18    http::{Method, StatusCode},
19};
20use async_graphql::{ParseRequestError, http::MultipartOptions};
21use futures_util::{
22    StreamExt, TryStreamExt,
23    future::{self, FutureExt},
24};
25
26/// Extractor for GraphQL request.
27///
28/// `async_graphql::http::MultipartOptions` allows to configure extraction
29/// process.
30pub struct GraphQLRequest(pub async_graphql::Request);
31
32impl GraphQLRequest {
33    /// Unwraps the value to `async_graphql::Request`.
34    #[must_use]
35    pub fn into_inner(self) -> async_graphql::Request {
36        self.0
37    }
38}
39
40type BatchToRequestMapper =
41    fn(<<GraphQLBatchRequest as FromRequest>::Future as Future>::Output) -> Result<GraphQLRequest>;
42
43impl FromRequest for GraphQLRequest {
44    type Error = Error;
45    type Future = future::Map<<GraphQLBatchRequest as FromRequest>::Future, BatchToRequestMapper>;
46
47    fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
48        GraphQLBatchRequest::from_request(req, payload).map(|res| {
49            Ok(Self(
50                res?.0
51                    .into_single()
52                    .map_err(actix_web::error::ErrorBadRequest)?,
53            ))
54        })
55    }
56}
57
58/// Extractor for GraphQL batch request.
59///
60/// `async_graphql::http::MultipartOptions` allows to configure extraction
61/// process.
62pub struct GraphQLBatchRequest(pub async_graphql::BatchRequest);
63
64impl GraphQLBatchRequest {
65    /// Unwraps the value to `async_graphql::BatchRequest`.
66    #[must_use]
67    pub fn into_inner(self) -> async_graphql::BatchRequest {
68        self.0
69    }
70}
71
72impl FromRequest for GraphQLBatchRequest {
73    type Error = Error;
74    type Future = Pin<Box<dyn Future<Output = Result<GraphQLBatchRequest>>>>;
75
76    fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
77        let config = req
78            .app_data::<MultipartOptions>()
79            .cloned()
80            .unwrap_or_default();
81
82        if req.method() == Method::GET {
83            let res = async_graphql::http::parse_query_string(req.query_string())
84                .map_err(io::Error::other);
85            Box::pin(async move { Ok(Self(async_graphql::BatchRequest::Single(res?))) })
86        } else if req.method() == Method::POST {
87            let content_type = req
88                .headers()
89                .get(http::header::CONTENT_TYPE)
90                .and_then(|value| value.to_str().ok())
91                .map(|value| value.to_string());
92
93            let (tx, rx) = async_channel::bounded(16);
94
95            // Payload is !Send so we create indirection with a channel
96            let mut payload = payload.take();
97            actix::spawn(async move {
98                while let Some(item) = payload.next().await {
99                    if tx.send(item).await.is_err() {
100                        return;
101                    }
102                }
103            });
104
105            Box::pin(async move {
106                Ok(GraphQLBatchRequest(
107                    async_graphql::http::receive_batch_body(
108                        content_type,
109                        rx.map_err(|e| match e {
110                            PayloadError::Incomplete(Some(e)) | PayloadError::Io(e) => e,
111                            PayloadError::Incomplete(None) => {
112                                io::Error::from(ErrorKind::UnexpectedEof)
113                            }
114                            PayloadError::EncodingCorrupted => io::Error::new(
115                                ErrorKind::InvalidData,
116                                "cannot decode content-encoding",
117                            ),
118                            PayloadError::Overflow => io::Error::new(
119                                ErrorKind::InvalidData,
120                                "a payload reached size limit",
121                            ),
122                            PayloadError::UnknownLength => {
123                                io::Error::other("a payload length is unknown")
124                            }
125                            #[cfg(feature = "http2")]
126                            PayloadError::Http2Payload(e) if e.is_io() => e.into_io().unwrap(),
127                            #[cfg(feature = "http2")]
128                            PayloadError::Http2Payload(e) => io::Error::other(e),
129                            _ => io::Error::other(e),
130                        })
131                        .into_async_read(),
132                        config,
133                    )
134                    .await
135                    .map_err(|err| match err {
136                        ParseRequestError::PayloadTooLarge => {
137                            actix_web::error::ErrorPayloadTooLarge(err)
138                        }
139                        _ => actix_web::error::ErrorBadRequest(err),
140                    })?,
141                ))
142            })
143        } else {
144            Box::pin(async move {
145                Err(actix_web::error::ErrorMethodNotAllowed(
146                    "GraphQL only supports GET and POST requests",
147                ))
148            })
149        }
150    }
151}
152
153/// Responder for a GraphQL response.
154///
155/// This contains a batch response, but since regular responses are a type of
156/// batch response it works for both.
157pub struct GraphQLResponse(pub async_graphql::BatchResponse);
158
159impl From<async_graphql::Response> for GraphQLResponse {
160    fn from(resp: async_graphql::Response) -> Self {
161        Self(resp.into())
162    }
163}
164
165impl From<async_graphql::BatchResponse> for GraphQLResponse {
166    fn from(resp: async_graphql::BatchResponse) -> Self {
167        Self(resp)
168    }
169}
170
#[cfg(feature = "cbor")]
mod cbor {
    //! Error type for CBOR serialization failures in the GraphQL responder.

    use std::fmt::{self, Display, Formatter};

    use actix_web::{http::StatusCode, ResponseError};

    /// Wraps a `serde_cbor` error so it can be turned into an actix-web error response.
    #[derive(Debug)]
    pub struct Error(pub serde_cbor::Error);

    impl Display for Error {
        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
            let Self(inner) = self;
            write!(f, "{inner}")
        }
    }

    impl ResponseError for Error {
        /// Always `500 Internal Server Error`: failing to encode the response
        /// is a server-side problem.
        fn status_code(&self) -> StatusCode {
            StatusCode::INTERNAL_SERVER_ERROR
        }
    }
}
190
191impl Responder for GraphQLResponse {
192    type Body = BoxBody;
193
194    fn respond_to(self, req: &HttpRequest) -> HttpResponse {
195        let mut builder = HttpResponse::build(StatusCode::OK);
196
197        if self.0.is_ok() {
198            if let Some(cache_control) = self.0.cache_control().value() {
199                builder.append_header((http::header::CACHE_CONTROL, cache_control));
200            }
201        }
202
203        let accept = req
204            .headers()
205            .get(http::header::ACCEPT)
206            .and_then(|val| val.to_str().ok());
207        let (ct, body) = match accept {
208            // optional cbor support
209            #[cfg(feature = "cbor")]
210            // this avoids copy-pasting the mime type
211            Some(ct @ "application/cbor") => (
212                ct,
213                match serde_cbor::to_vec(&self.0) {
214                    Ok(body) => body,
215                    Err(e) => return HttpResponse::from_error(cbor::Error(e)),
216                },
217            ),
218            _ => (
219                "application/graphql-response+json",
220                match serde_json::to_vec(&self.0) {
221                    Ok(body) => body,
222                    Err(e) => return HttpResponse::from_error(JsonPayloadError::Serialize(e)),
223                },
224            ),
225        };
226
227        let mut resp = builder.content_type(ct).body(body);
228
229        for (name, value) in self.0.http_headers_iter() {
230            if let (Ok(name), Ok(value)) = (
231                HeaderName::from_str(name.as_str()),
232                HeaderValue::from_bytes(value.as_bytes()),
233            ) {
234                resp.headers_mut().append(name, value);
235            }
236        }
237
238        resp
239    }
240}