1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
use std::{io::ErrorKind, marker::PhantomData};

use async_graphql::{futures_util::TryStreamExt, http::MultipartOptions, ParseRequestError};
use axum::{
    extract::{FromRequest, Request},
    http::{self, Method},
    response::IntoResponse,
};
use tokio_util::compat::TokioAsyncReadCompatExt;

/// Extractor for GraphQL request.
pub struct GraphQLRequest<R = rejection::GraphQLRejection>(
    pub async_graphql::Request,
    PhantomData<R>,
);

impl<R> GraphQLRequest<R> {
    /// Unwraps the value to `async_graphql::Request`.
    #[must_use]
    pub fn into_inner(self) -> async_graphql::Request {
        self.0
    }
}

/// Rejection response types.
pub mod rejection {
    use async_graphql::ParseRequestError;
    use axum::{
        body::Body,
        http,
        http::StatusCode,
        response::{IntoResponse, Response},
    };

    /// Rejection used for [`GraphQLRequest`](GraphQLRequest).
    pub struct GraphQLRejection(pub ParseRequestError);

    impl IntoResponse for GraphQLRejection {
        fn into_response(self) -> Response {
            match self.0 {
                ParseRequestError::PayloadTooLarge => http::Response::builder()
                    .status(StatusCode::PAYLOAD_TOO_LARGE)
                    .body(Body::empty())
                    .unwrap(),
                bad_request => http::Response::builder()
                    .status(StatusCode::BAD_REQUEST)
                    .body(Body::from(format!("{:?}", bad_request)))
                    .unwrap(),
            }
        }
    }

    impl From<ParseRequestError> for GraphQLRejection {
        fn from(err: ParseRequestError) -> Self {
            GraphQLRejection(err)
        }
    }
}

#[async_trait::async_trait]
impl<S, R> FromRequest<S> for GraphQLRequest<R>
where
    S: Send + Sync,
    R: IntoResponse + From<ParseRequestError>,
{
    type Rejection = R;

    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
        Ok(GraphQLRequest(
            GraphQLBatchRequest::<R>::from_request(req, state)
                .await?
                .0
                .into_single()?,
            PhantomData,
        ))
    }
}

/// Extractor for GraphQL batch request.
pub struct GraphQLBatchRequest<R = rejection::GraphQLRejection>(
    pub async_graphql::BatchRequest,
    PhantomData<R>,
);

impl<R> GraphQLBatchRequest<R> {
    /// Unwraps the value to `async_graphql::BatchRequest`.
    #[must_use]
    pub fn into_inner(self) -> async_graphql::BatchRequest {
        self.0
    }
}

#[async_trait::async_trait]
impl<S, R> FromRequest<S> for GraphQLBatchRequest<R>
where
    S: Send + Sync,
    R: IntoResponse + From<ParseRequestError>,
{
    type Rejection = R;

    async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
        if req.method() == Method::GET {
            let uri = req.uri();
            let res = async_graphql::http::parse_query_string(uri.query().unwrap_or_default())
                .map_err(|err| {
                    ParseRequestError::Io(std::io::Error::new(
                        ErrorKind::Other,
                        format!("failed to parse graphql request from uri query: {}", err),
                    ))
                });
            Ok(Self(async_graphql::BatchRequest::Single(res?), PhantomData))
        } else {
            let content_type = req
                .headers()
                .get(http::header::CONTENT_TYPE)
                .and_then(|value| value.to_str().ok())
                .map(ToString::to_string);
            let body_stream = req
                .into_body()
                .into_data_stream()
                .map_err(|err| std::io::Error::new(ErrorKind::Other, err.to_string()));
            let body_reader = tokio_util::io::StreamReader::new(body_stream).compat();
            Ok(Self(
                async_graphql::http::receive_batch_body(
                    content_type,
                    body_reader,
                    MultipartOptions::default(),
                )
                .await?,
                PhantomData,
            ))
        }
    }
}