async_graphql_axum/
query.rs

1use std::{
2    convert::Infallible,
3    task::{Context, Poll},
4    time::Duration,
5};
6
7use async_graphql::{
8    Executor,
9    http::{create_multipart_mixed_stream, is_accept_multipart_mixed},
10};
11use axum::{
12    BoxError,
13    body::{Body, HttpBody},
14    extract::FromRequest,
15    http::{Request as HttpRequest, Response as HttpResponse},
16    response::IntoResponse,
17};
18use bytes::Bytes;
19use futures_util::{StreamExt, future::BoxFuture};
20use tower_service::Service;
21
22use crate::{
23    GraphQLBatchRequest, GraphQLRequest, GraphQLResponse, extract::rejection::GraphQLRejection,
24};
25
/// A GraphQL service.
///
/// Wraps an executor and, when `E` implements `Executor`, serves GraphQL over
/// HTTP by implementing `tower_service::Service` for `http::Request`s.
///
/// `Clone` and `Debug` are derived, so they are available whenever `E`
/// provides them.
#[derive(Clone, Debug)]
pub struct GraphQL<E> {
    /// The executor every incoming request is run against.
    executor: E,
}
31
32impl<E> GraphQL<E> {
33    /// Create a GraphQL handler.
34    pub fn new(executor: E) -> Self {
35        Self { executor }
36    }
37}
38
39impl<B, E> Service<HttpRequest<B>> for GraphQL<E>
40where
41    B: HttpBody<Data = Bytes> + Send + 'static,
42    B::Data: Into<Bytes>,
43    B::Error: Into<BoxError>,
44    E: Executor,
45{
46    type Response = HttpResponse<Body>;
47    type Error = Infallible;
48    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
49
50    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
51        Poll::Ready(Ok(()))
52    }
53
54    fn call(&mut self, req: HttpRequest<B>) -> Self::Future {
55        let executor = self.executor.clone();
56        let req = req.map(Body::new);
57        Box::pin(async move {
58            let is_accept_multipart_mixed = req
59                .headers()
60                .get("accept")
61                .and_then(|value| value.to_str().ok())
62                .map(is_accept_multipart_mixed)
63                .unwrap_or_default();
64
65            if is_accept_multipart_mixed {
66                let req = match GraphQLRequest::<GraphQLRejection>::from_request(req, &()).await {
67                    Ok(req) => req,
68                    Err(err) => return Ok(err.into_response()),
69                };
70                let stream = executor.execute_stream(req.0, None);
71                let body = Body::from_stream(
72                    create_multipart_mixed_stream(stream, Duration::from_secs(30))
73                        .map(Ok::<_, std::io::Error>),
74                );
75                Ok(HttpResponse::builder()
76                    .header("content-type", "multipart/mixed; boundary=graphql")
77                    .body(body)
78                    .expect("BUG: invalid response"))
79            } else {
80                let req =
81                    match GraphQLBatchRequest::<GraphQLRejection>::from_request(req, &()).await {
82                        Ok(req) => req,
83                        Err(err) => return Ok(err.into_response()),
84                    };
85                Ok(GraphQLResponse(executor.execute_batch(req.0).await).into_response())
86            }
87        })
88    }
89}