use std::{
convert::Infallible,
task::{Context, Poll},
time::Duration,
};
use async_graphql::{
http::{create_multipart_mixed_stream, is_accept_multipart_mixed},
Executor,
};
use axum::{
body::{BoxBody, HttpBody, StreamBody},
extract::FromRequest,
http::{Request as HttpRequest, Response as HttpResponse},
response::IntoResponse,
BoxError,
};
use bytes::Bytes;
use futures_util::{future::BoxFuture, StreamExt};
use tower_service::Service;
use crate::{
extract::rejection::GraphQLRejection, GraphQLBatchRequest, GraphQLRequest, GraphQLResponse,
};
#[derive(Clone)]
pub struct GraphQL<E> {
executor: E,
}
impl<E> GraphQL<E> {
pub fn new(executor: E) -> Self {
Self { executor }
}
}
impl<B, E> Service<HttpRequest<B>> for GraphQL<E>
where
B: HttpBody + Send + Sync + 'static,
B::Data: Into<Bytes>,
B::Error: Into<BoxError>,
E: Executor,
{
type Response = HttpResponse<BoxBody>;
type Error = Infallible;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: HttpRequest<B>) -> Self::Future {
let executor = self.executor.clone();
Box::pin(async move {
let is_accept_multipart_mixed = req
.headers()
.get("accept")
.and_then(|value| value.to_str().ok())
.map(is_accept_multipart_mixed)
.unwrap_or_default();
if is_accept_multipart_mixed {
let req = match GraphQLRequest::<GraphQLRejection>::from_request(req, &()).await {
Ok(req) => req,
Err(err) => return Ok(err.into_response()),
};
let stream = executor.execute_stream(req.0, None);
let body = StreamBody::new(
create_multipart_mixed_stream(
stream,
tokio_stream::wrappers::IntervalStream::new(tokio::time::interval(
Duration::from_secs(30),
))
.map(|_| ()),
)
.map(Ok::<_, std::io::Error>),
);
Ok(HttpResponse::builder()
.header("content-type", "multipart/mixed; boundary=graphql")
.body(body.boxed_unsync())
.expect("BUG: invalid response"))
} else {
let req =
match GraphQLBatchRequest::<GraphQLRejection>::from_request(req, &()).await {
Ok(req) => req,
Err(err) => return Ok(err.into_response()),
};
Ok(GraphQLResponse(executor.execute_batch(req.0).await).into_response())
}
})
}
}