async_graphql_axum_wasi/
extract.rs1use std::{io::ErrorKind, marker::PhantomData};
2
3use async_graphql::{futures_util::TryStreamExt, http::MultipartOptions, ParseRequestError};
4use axum::{
5 extract::{BodyStream, FromRequest},
6 http::{self, Method, Request},
7 response::IntoResponse,
8 BoxError,
9};
10use bytes::Bytes;
11use tokio_util::compat::TokioAsyncReadCompatExt;
12
13pub struct GraphQLRequest<R = rejection::GraphQLRejection>(
15 pub async_graphql::Request,
16 PhantomData<R>,
17);
18
19impl<R> GraphQLRequest<R> {
20 #[must_use]
22 pub fn into_inner(self) -> async_graphql::Request {
23 self.0
24 }
25}
26
27pub mod rejection {
29 use async_graphql::ParseRequestError;
30 use axum::{
31 body::{boxed, Body, BoxBody},
32 http,
33 http::StatusCode,
34 response::IntoResponse,
35 };
36
37 pub struct GraphQLRejection(pub ParseRequestError);
39
40 impl IntoResponse for GraphQLRejection {
41 fn into_response(self) -> http::Response<BoxBody> {
42 match self.0 {
43 ParseRequestError::PayloadTooLarge => http::Response::builder()
44 .status(StatusCode::PAYLOAD_TOO_LARGE)
45 .body(boxed(Body::empty()))
46 .unwrap(),
47 bad_request => http::Response::builder()
48 .status(StatusCode::BAD_REQUEST)
49 .body(boxed(Body::from(format!("{:?}", bad_request))))
50 .unwrap(),
51 }
52 }
53 }
54
55 impl From<ParseRequestError> for GraphQLRejection {
56 fn from(err: ParseRequestError) -> Self {
57 GraphQLRejection(err)
58 }
59 }
60}
61
62#[async_trait::async_trait]
63impl<S, B, R> FromRequest<S, B> for GraphQLRequest<R>
64where
65 B: http_body::Body + Unpin + Send + Sync + 'static,
66 B::Data: Into<Bytes>,
67 B::Error: Into<BoxError>,
68 S: Send + Sync,
69 R: IntoResponse + From<ParseRequestError>,
70{
71 type Rejection = R;
72
73 async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
74 Ok(GraphQLRequest(
75 GraphQLBatchRequest::<R>::from_request(req, state)
76 .await?
77 .0
78 .into_single()?,
79 PhantomData,
80 ))
81 }
82}
83
84pub struct GraphQLBatchRequest<R = rejection::GraphQLRejection>(
86 pub async_graphql::BatchRequest,
87 PhantomData<R>,
88);
89
90impl<R> GraphQLBatchRequest<R> {
91 #[must_use]
93 pub fn into_inner(self) -> async_graphql::BatchRequest {
94 self.0
95 }
96}
97
98#[async_trait::async_trait]
99impl<S, B, R> FromRequest<S, B> for GraphQLBatchRequest<R>
100where
101 B: http_body::Body + Unpin + Send + Sync + 'static,
102 B::Data: Into<Bytes>,
103 B::Error: Into<BoxError>,
104 S: Send + Sync,
105 R: IntoResponse + From<ParseRequestError>,
106{
107 type Rejection = R;
108
109 async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
110 if let (&Method::GET, uri) = (req.method(), req.uri()) {
111 let res = async_graphql::http::parse_query_string(uri.query().unwrap_or_default())
112 .map_err(|err| {
113 ParseRequestError::Io(std::io::Error::new(
114 ErrorKind::Other,
115 format!("failed to parse graphql request from uri query: {}", err),
116 ))
117 });
118 Ok(Self(async_graphql::BatchRequest::Single(res?), PhantomData))
119 } else {
120 let content_type = req
121 .headers()
122 .get(http::header::CONTENT_TYPE)
123 .and_then(|value| value.to_str().ok())
124 .map(ToString::to_string);
125 let body_stream = BodyStream::from_request(req, state)
126 .await
127 .map_err(|_| {
128 ParseRequestError::Io(std::io::Error::new(
129 ErrorKind::Other,
130 "body has been taken by another extractor".to_string(),
131 ))
132 })?
133 .map_err(|err| std::io::Error::new(ErrorKind::Other, err.to_string()));
134 let body_reader = tokio_util::io::StreamReader::new(body_stream).compat();
135 Ok(Self(
136 async_graphql::http::receive_batch_body(
137 content_type,
138 body_reader,
139 MultipartOptions::default(),
140 )
141 .await?,
142 PhantomData,
143 ))
144 }
145 }
146}