// rusty_gql_axum/request.rs

1use axum::extract::{BodyStream, FromRequest};
2use axum::http::{Method, StatusCode};
3use axum::response::{IntoResponse, Response};
4use axum::{body, BoxError};
5use bytes::Bytes;
6use futures_util::TryStreamExt;
7use rusty_gql::{receive_http_request, HttpRequestError};
8use tokio_util::compat::TokioAsyncReadCompatExt;
9
10pub struct GqlRequest(pub rusty_gql::Request);
11
12pub struct GqlRejection(pub HttpRequestError);
13
14impl IntoResponse for GqlRejection {
15    fn into_response(self) -> Response {
16        let body = body::boxed(body::Full::from(format!("{:?}", self.0)));
17        Response::builder()
18            .status(StatusCode::BAD_REQUEST)
19            .body(body)
20            .unwrap()
21    }
22}
23
24impl From<HttpRequestError> for GqlRejection {
25    fn from(error: HttpRequestError) -> Self {
26        GqlRejection(error)
27    }
28}
29
30#[async_trait::async_trait]
31impl<B> FromRequest<B> for GqlRequest
32where
33    B: http_body::Body + Unpin + Send + Sync + 'static,
34    B::Data: Into<Bytes>,
35    B::Error: Into<BoxError>,
36{
37    type Rejection = GqlRejection;
38    async fn from_request(
39        req: &mut axum::extract::RequestParts<B>,
40    ) -> Result<Self, Self::Rejection> {
41        if let (&Method::GET, uri) = (req.method(), req.uri()) {
42            let res = serde_urlencoded::from_str(uri.query().unwrap_or_default()).map_err(|err| {
43                HttpRequestError::Io(std::io::Error::new(
44                    std::io::ErrorKind::Other,
45                    format!("failed to parse graphql requst from query params: {}", err),
46                ))
47            });
48            Ok(Self(res?))
49        } else {
50            let body_stream = BodyStream::from_request(req)
51                .await
52                .map_err(|err| {
53                    HttpRequestError::Io(std::io::Error::new(
54                        std::io::ErrorKind::Other,
55                        err.to_string(),
56                    ))
57                })?
58                .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err.to_string()));
59            let body_reader = tokio_util::io::StreamReader::new(body_stream).compat();
60            Ok(Self(receive_http_request(body_reader).await?))
61        }
62    }
63}