use async_graphql::BatchRequest as GqlBatchRequest;
use async_graphql::http::MultipartOptions;
use http::StatusCode;
use http_body_util::BodyExt;
use crate::extractors::FromRequest;
use crate::responder::Responder;
use crate::types::Request;
use crate::types::Response;
/// Extractor that yields a single decoded [`async_graphql::Request`].
pub struct GraphQLRequest(pub async_graphql::Request);

impl GraphQLRequest {
    /// Consumes the extractor and hands back the wrapped GraphQL request.
    pub fn into_inner(self) -> async_graphql::Request {
        let Self(inner) = self;
        inner
    }
}
/// Extractor that yields a decoded GraphQL request which may be a single
/// operation or a batch.
pub struct GraphQLBatchRequest(pub GqlBatchRequest);

impl GraphQLBatchRequest {
    /// Consumes the extractor and hands back the wrapped batch request.
    pub fn into_inner(self) -> GqlBatchRequest {
        let Self(inner) = self;
        inner
    }
}
/// Largest GraphQL request body, in bytes, that will be buffered (4 MiB).
///
/// `read_body_bytes` checks it against a parseable `Content-Length` header up
/// front and again while streaming the body.
pub const MAX_GRAPHQL_BODY_SIZE: usize = 4 << 20;
/// Reasons a GraphQL request could not be extracted from an HTTP request.
///
/// This type's `Responder` impl turns each variant into an HTTP error response.
#[derive(Debug)]
pub enum GraphQLError {
    /// No GraphQL query was supplied (answered with 400 "Missing GraphQL query").
    MissingQuery,
    /// Reading the body failed for a reason other than the size limit; carries
    /// the underlying error text (answered with 400).
    BodyRead(String),
    /// The declared or streamed body size exceeded [`MAX_GRAPHQL_BODY_SIZE`]
    /// (answered with 413).
    BodyTooLarge,
    /// A JSON payload could not be decoded; carries a detail message
    /// (answered with 400).
    InvalidJson(String),
    /// The query string or body could not be turned into a GraphQL request,
    /// including an empty body; carries a description (answered with 400).
    Parse(String),
    /// The `Content-Type` header was absent/blank (carried as `<missing>`) or
    /// named a media type that is not accepted; carries the offending value
    /// (answered with 415).
    UnsupportedMediaType(String),
}
/// Settings that control how GraphQL request bodies are decoded.
///
/// `resolve_opts` looks for them first in the request's extensions (see
/// `attach_graphql_options`), then in shared state (see
/// `set_global_graphql_options`), and otherwise falls back to defaults.
#[derive(Clone, Default)]
pub struct GraphQLOptions {
    /// Forwarded to async-graphql's body decoder for non-`GET` requests.
    pub multipart: MultipartOptions,
}
impl Responder for GraphQLError {
    /// Maps each extraction failure to an HTTP status and a human-readable
    /// message: 413 for oversized bodies, 415 for rejected `Content-Type`s,
    /// and 400 for every other variant.
    fn into_response(self) -> Response {
        // Shared shape for the 400 variants whose message is built at runtime.
        let bad_request = |detail: String| (StatusCode::BAD_REQUEST, detail).into_response();

        match self {
            Self::MissingQuery => {
                (StatusCode::BAD_REQUEST, "Missing GraphQL query").into_response()
            }
            Self::BodyRead(e) => bad_request(format!("Failed to read body: {e}")),
            Self::InvalidJson(e) => bad_request(format!("Invalid JSON: {e}")),
            Self::Parse(e) => bad_request(format!("Invalid request: {e}")),
            Self::BodyTooLarge => (
                StatusCode::PAYLOAD_TOO_LARGE,
                format!("GraphQL body exceeds {MAX_GRAPHQL_BODY_SIZE} bytes"),
            )
                .into_response(),
            Self::UnsupportedMediaType(ct) => (
                StatusCode::UNSUPPORTED_MEDIA_TYPE,
                format!("Unsupported GraphQL content-type: {ct}"),
            )
                .into_response(),
        }
    }
}
/// Classifies a request `Content-Type` header value into the body format it
/// announces.
///
/// The comparison only uses the media-type essence: the text before the first
/// `;`, trimmed and ASCII-lowercased. Parameters such as `charset` or the
/// multipart `boundary` are therefore ignored here.
///
/// * `application/json`, `application/graphql-response+json` → [`GraphQLBodyKind::Json`]
/// * `application/graphql` → [`GraphQLBodyKind::Graphql`] (body is the query text)
/// * `multipart/form-data` → [`GraphQLBodyKind::Multipart`]
///
/// # Errors
///
/// Returns [`GraphQLError::UnsupportedMediaType`] when the header is absent or
/// blank (reported as `<missing>`). It also returns it for any other media
/// type, reporting the trimmed header value verbatim.
fn classify_graphql_content_type(ct: Option<&str>) -> Result<GraphQLBodyKind, GraphQLError> {
    let raw = ct.unwrap_or("").trim();
    if raw.is_empty() {
        return Err(GraphQLError::UnsupportedMediaType("<missing>".to_string()));
    }
    let essence = raw
        .split(';')
        .next()
        .unwrap_or("")
        .trim()
        .to_ascii_lowercase();
    match essence.as_str() {
        // `application/graphql-response+json` is the GraphQL-over-HTTP
        // *response* media type and its payload is JSON. It is tolerated as a
        // request type but must be decoded as JSON, never as raw query text.
        "application/json" | "application/graphql-response+json" => Ok(GraphQLBodyKind::Json),
        "application/graphql" => Ok(GraphQLBodyKind::Graphql),
        "multipart/form-data" => Ok(GraphQLBodyKind::Multipart),
        _ => Err(GraphQLError::UnsupportedMediaType(raw.to_string())),
    }
}
/// Body format announced by a request's `Content-Type`, as decided by
/// `classify_graphql_content_type`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum GraphQLBodyKind {
    /// JSON request envelope (e.g. `application/json`).
    Json,
    /// Raw GraphQL media type (e.g. `application/graphql`).
    Graphql,
    /// `multipart/form-data`.
    Multipart,
}
/// Picks the multipart options that apply to `req`.
///
/// Precedence:
/// 1. `GraphQLOptions` attached to the request's extensions.
/// 2. `GraphQLOptions` registered in `crate::state`.
/// 3. `MultipartOptions::default()`.
///
/// The global lookup only happens when the request carries no options.
#[inline]
fn resolve_opts(req: &Request) -> MultipartOptions {
    let per_request = req
        .extensions()
        .get::<GraphQLOptions>()
        .map(|local| local.multipart);
    per_request
        .or_else(|| {
            crate::state::get_state::<GraphQLOptions>().map(|global| global.as_ref().multipart)
        })
        .unwrap_or_default()
}
/// Parses a GraphQL request from the URI query string of a `GET` request.
///
/// # Errors
///
/// * [`GraphQLError::MissingQuery`] when the URI has no query string, or an
///   empty one (`/graphql` or `/graphql?`). Such a request carries no GraphQL
///   parameters at all.
/// * [`GraphQLError::Parse`] when async-graphql rejects the query string.
fn parse_get_request(req: &Request) -> Result<async_graphql::Request, GraphQLError> {
    // Fail fast with the dedicated 400 rather than handing an empty string to
    // async-graphql. What it does with one is a far less helpful answer for
    // the client.
    let Some(qs) = req.uri().query().filter(|qs| !qs.is_empty()) else {
        return Err(GraphQLError::MissingQuery);
    };
    async_graphql::http::parse_query_string(qs).map_err(|e| GraphQLError::Parse(e.to_string()))
}
/// Takes the body out of `req` and buffers it into memory.
///
/// `req` is left holding its default (empty) body. The size cap
/// `MAX_GRAPHQL_BODY_SIZE` is enforced twice:
/// * up front, against a `Content-Length` header that parses as `usize`;
/// * while streaming, through `http_body_util::Limited`. This also catches a
///   missing or understated `Content-Length`.
///
/// # Errors
///
/// * `GraphQLError::BodyTooLarge` when either check trips.
/// * `GraphQLError::BodyRead` for any other failure while collecting the body.
async fn read_body_bytes(req: &mut Request) -> Result<bytes::Bytes, GraphQLError> {
    let declared_len = req
        .headers()
        .get(http::header::CONTENT_LENGTH)
        .and_then(|value| value.to_str().ok())
        .and_then(|text| text.parse::<usize>().ok());
    if declared_len.is_some_and(|len| len > MAX_GRAPHQL_BODY_SIZE) {
        return Err(GraphQLError::BodyTooLarge);
    }

    let capped = http_body_util::Limited::new(std::mem::take(req.body_mut()), MAX_GRAPHQL_BODY_SIZE);
    capped
        .collect()
        .await
        .map(|collected| collected.to_bytes())
        .map_err(|err| {
            if err.is::<http_body_util::LengthLimitError>() {
                GraphQLError::BodyTooLarge
            } else {
                GraphQLError::BodyRead(err.to_string())
            }
        })
}
impl<'a> FromRequest<'a> for GraphQLRequest {
type Error = GraphQLError;
fn from_request(
req: &'a mut Request,
) -> impl core::future::Future<Output = core::result::Result<Self, Self::Error>> + Send + 'a {
async move {
if req.method() == http::Method::GET {
return Ok(GraphQLRequest(parse_get_request(req)?));
}
let opts = resolve_opts(req);
let content_type = req
.headers()
.get(http::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.map(std::string::ToString::to_string);
classify_graphql_content_type(content_type.as_deref())?;
let body = read_body_bytes(req).await?;
if body.is_empty() {
return Err(GraphQLError::Parse("empty request body".to_string()));
}
let reader = futures_util::io::Cursor::new(body.to_vec());
let req = async_graphql::http::receive_body(content_type.as_deref(), reader, opts)
.await
.map_err(|e| GraphQLError::Parse(e.to_string()))?;
Ok(GraphQLRequest(req))
}
}
}
/// Stores `opts` in `req`'s extensions.
///
/// `resolve_opts` checks the request extensions before the global options, so
/// these take precedence for this request only.
pub fn attach_graphql_options(req: &mut Request, opts: GraphQLOptions) {
    let extensions = req.extensions_mut();
    extensions.insert(opts);
}
pub fn set_global_graphql_options(opts: GraphQLOptions) {
crate::state::set_state::<GraphQLOptions>(opts);
}
pub async fn receive_graphql(
req: &mut Request,
opts: MultipartOptions,
) -> Result<async_graphql::Request, GraphQLError> {
if req.method() == http::Method::GET {
return parse_get_request(req);
}
let body = read_body_bytes(req).await?;
let content_type = req
.headers()
.get(http::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.map(std::string::ToString::to_string);
let reader = futures_util::io::Cursor::new(body.to_vec());
async_graphql::http::receive_body(content_type.as_deref(), reader, opts)
.await
.map_err(|e| GraphQLError::Parse(e.to_string()))
}
pub async fn receive_graphql_batch(
req: &mut Request,
opts: MultipartOptions,
) -> Result<GqlBatchRequest, GraphQLError> {
if req.method() == http::Method::GET {
let single = parse_get_request(req)?;
return Ok(GqlBatchRequest::Single(single));
}
let content_type = req
.headers()
.get(http::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.map(std::string::ToString::to_string);
classify_graphql_content_type(content_type.as_deref())?;
let body = read_body_bytes(req).await?;
if body.is_empty() {
return Err(GraphQLError::Parse("empty request body".to_string()));
}
let reader = futures_util::io::Cursor::new(body.to_vec());
async_graphql::http::receive_batch_body(content_type.as_deref(), reader, opts)
.await
.map_err(|e| GraphQLError::Parse(e.to_string()))
}
impl<'a> FromRequest<'a> for GraphQLBatchRequest {
type Error = GraphQLError;
fn from_request(
req: &'a mut Request,
) -> impl core::future::Future<Output = core::result::Result<Self, Self::Error>> + Send + 'a {
async move {
if req.method() == http::Method::GET {
let single = parse_get_request(req)?;
return Ok(GraphQLBatchRequest(GqlBatchRequest::Single(single)));
}
let opts = resolve_opts(req);
let content_type = req
.headers()
.get(http::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.map(std::string::ToString::to_string);
classify_graphql_content_type(content_type.as_deref())?;
let body = read_body_bytes(req).await?;
if body.is_empty() {
return Err(GraphQLError::Parse("empty request body".to_string()));
}
let reader = futures_util::io::Cursor::new(body.to_vec());
let batch = async_graphql::http::receive_batch_body(content_type.as_deref(), reader, opts)
.await
.map_err(|e| GraphQLError::Parse(e.to_string()))?;
Ok(GraphQLBatchRequest(batch))
}
}
}