use std::fmt;
use axum::{
Json, RequestExt as _,
body::Body,
extract::{FromRequest, FromRequestParts, Query},
http::{HeaderValue, Method, Request, StatusCode, header},
response::{IntoResponse as _, Response},
};
use juniper::{
DefaultScalarValue, ScalarValue,
http::{GraphQLBatchRequest, GraphQLRequest},
};
use serde::Deserialize;
/// A GraphQL request (either a single operation or a batch) extracted from an HTTP request.
///
/// The wrapped [`GraphQLBatchRequest`] is public, so handlers may destructure it directly.
/// When no scalar type is specified, [`DefaultScalarValue`] is used.
#[derive(Debug, PartialEq)]
pub struct JuniperRequest<S: ScalarValue = DefaultScalarValue>(pub GraphQLBatchRequest<S>);
/// Extracts a [`JuniperRequest`] from an incoming HTTP request.
///
/// Accepted request shapes:
/// - `GET` with `query` and optional `operationName` / `variables` (JSON-encoded) query-string
///   parameters, always producing a single (non-batched) request;
/// - `POST` with `Content-Type: application/json`, whose body is parsed as a single or batched
///   request;
/// - `POST` with `Content-Type: application/graphql`, whose whole body is taken as the query
///   text.
///
/// The `Content-Type` media type is compared case-insensitively, and its parameters
/// (e.g. `; charset=utf-8`) are ignored.
///
/// # Errors
///
/// Rejects with a ready-made [`Response`]:
/// - `400 Bad Request` when the `Content-Type` header is not a valid string, the query string or
///   its `variables` cannot be parsed, the JSON body is rejected, or the `application/graphql`
///   body cannot be read as a `String` (reported as `Not valid UTF-8 body`);
/// - `415 Unsupported Media Type` for a `POST` whose `Content-Type` is missing or is neither
///   `application/json` nor `application/graphql`;
/// - `405 Method Not Allowed` for any method other than `GET` or `POST`.
impl<S, State> FromRequest<State> for JuniperRequest<S>
where
    S: ScalarValue,
    State: Sync,
    Query<GetRequest>: FromRequestParts<State>,
    Json<GraphQLBatchRequest<S>>: FromRequest<State>,
    <Json<GraphQLBatchRequest<S>> as FromRequest<State>>::Rejection: fmt::Display,
    String: FromRequest<State>,
{
    type Rejection = Response;

    async fn from_request(mut req: Request<Body>, state: &State) -> Result<Self, Self::Rejection> {
        let content_type = req
            .headers()
            .get(header::CONTENT_TYPE)
            .map(HeaderValue::to_str)
            .transpose()
            .map_err(|_| {
                (
                    StatusCode::BAD_REQUEST,
                    "`Content-Type` header is not a valid HTTP header string",
                )
                    .into_response()
            })?;

        match (req.method(), content_type) {
            // `GET` requests carry everything in the query string, so `Content-Type` is ignored.
            (&Method::GET, _) => req
                .extract_parts::<Query<GetRequest>>()
                .await
                .map_err(|e| {
                    (
                        StatusCode::BAD_REQUEST,
                        format!("Invalid request query string: {e}"),
                    )
                        .into_response()
                })
                .and_then(|query| {
                    query
                        .0
                        .try_into()
                        .map(|q| Self(GraphQLBatchRequest::Single(q)))
                        .map_err(|e| {
                            (
                                StatusCode::BAD_REQUEST,
                                format!("Invalid request query `variables`: {e}"),
                            )
                                .into_response()
                        })
                }),
            (&Method::POST, Some(ct)) if has_media_type(ct, "application/json") => {
                Json::<GraphQLBatchRequest<S>>::from_request(req, state)
                    .await
                    .map(|req| Self(req.0))
                    .map_err(|e| {
                        (StatusCode::BAD_REQUEST, format!("Invalid JSON body: {e}")).into_response()
                    })
            }
            (&Method::POST, Some(ct)) if has_media_type(ct, "application/graphql") => {
                String::from_request(req, state)
                    .await
                    .map(|body| {
                        Self(GraphQLBatchRequest::Single(GraphQLRequest::new(
                            body, None, None,
                        )))
                    })
                    .map_err(|_| (StatusCode::BAD_REQUEST, "Not valid UTF-8 body").into_response())
            }
            (&Method::POST, _) => Err((
                StatusCode::UNSUPPORTED_MEDIA_TYPE,
                "`Content-Type` header is expected to be either `application/json` or \
                 `application/graphql`",
            )
                .into_response()),
            _ => Err((
                StatusCode::METHOD_NOT_ALLOWED,
                "HTTP method is expected to be either GET or POST",
            )
                .into_response()),
        }
    }
}

/// Checks whether a `Content-Type` header value denotes the `expected` media type.
///
/// Only the `type/subtype` essence (everything before the first `;`) is compared. Surrounding
/// whitespace is trimmed and ASCII case is ignored, because media type names are
/// case-insensitive (RFC 9110, section 8.3.1). Parameters such as `charset` are ignored, while
/// different media types that merely share a prefix (e.g. `application/jsonl`) do not match.
fn has_media_type(content_type: &str, expected: &str) -> bool {
    let essence = content_type
        .split_once(';')
        .map_or(content_type, |(essence, _params)| essence);
    essence.trim().eq_ignore_ascii_case(expected)
}
/// Query-string parameters of a GraphQL `GET` request.
///
/// Parameters other than the three fields below are rejected during deserialization.
#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
struct GetRequest {
    /// GraphQL document text (the `query` parameter).
    query: String,
    /// Name of the operation to execute (the `operationName` parameter).
    operation_name: Option<String>,
    /// Still-undecoded JSON text of the variables (the `variables` parameter).
    variables: Option<String>,
}
/// Turns parsed `GET` query-string parameters into a single [`GraphQLRequest`].
impl<S: ScalarValue> TryFrom<GetRequest> for GraphQLRequest<S> {
    type Error = serde_json::Error;

    /// # Errors
    ///
    /// Fails with a [`serde_json::Error`] when `variables` is present but cannot be decoded
    /// from JSON.
    fn try_from(req: GetRequest) -> Result<Self, Self::Error> {
        // The query string only carries raw text, so `variables` has to be decoded as JSON here.
        let decoded_variables = match req.variables {
            None => None,
            Some(raw) => Some(serde_json::from_str(&raw)?),
        };
        Ok(Self::new(req.query, req.operation_name, decoded_variables))
    }
}
#[cfg(test)]
mod juniper_request_tests {
    use axum::{body::Body, extract::FromRequest as _, http::Request};
    use futures::TryStreamExt as _;
    use juniper::{
        InputValue, graphql_input_value,
        http::{GraphQLBatchRequest, GraphQLRequest},
    };

    use super::JuniperRequest;

    /// Query used by the variable-less test cases.
    const ADD_QUERY: &str = "{ add(a: 2, b: 3) }";

    /// Query used by the test case passing `variables`.
    const HUMAN_QUERY: &str =
        "query($id: String!) { human(id: $id) { id, name, appearsIn, homePlanet } }";

    #[tokio::test]
    async fn from_get_request() {
        let uri = format!("/?query={}", urlencoding::encode(ADD_QUERY));

        assert_eq!(do_from_request(get(&uri)).await, single(ADD_QUERY, None));
    }

    #[tokio::test]
    async fn from_get_request_with_variables() {
        let uri = format!(
            "/?query={}&variables={}",
            urlencoding::encode(HUMAN_QUERY),
            urlencoding::encode(r#"{"id": "1000"}"#),
        );
        let expected = single(HUMAN_QUERY, Some(graphql_input_value!({"id": "1000"})));

        assert_eq!(do_from_request(get(&uri)).await, expected);
    }

    #[tokio::test]
    async fn from_json_post_request() {
        let req = post("application/json", r#"{"query": "{ add(a: 2, b: 3) }"}"#);

        assert_eq!(do_from_request(req).await, single(ADD_QUERY, None));
    }

    #[tokio::test]
    async fn from_json_post_request_with_charset() {
        let req = post(
            "application/json; charset=utf-8",
            r#"{"query": "{ add(a: 2, b: 3) }"}"#,
        );

        assert_eq!(do_from_request(req).await, single(ADD_QUERY, None));
    }

    #[tokio::test]
    async fn from_graphql_post_request() {
        let req = post("application/graphql", r#"{ add(a: 2, b: 3) }"#);

        assert_eq!(do_from_request(req).await, single(ADD_QUERY, None));
    }

    #[tokio::test]
    async fn from_graphql_post_request_with_charset() {
        let req = post("application/graphql; charset=utf-8", r#"{ add(a: 2, b: 3) }"#);

        assert_eq!(do_from_request(req).await, single(ADD_QUERY, None));
    }

    /// Builds the expected single (non-batched) request with no operation name.
    fn single(query: &str, variables: Option<InputValue>) -> JuniperRequest {
        JuniperRequest(GraphQLBatchRequest::Single(GraphQLRequest::new(
            query.to_owned(),
            None,
            variables,
        )))
    }

    /// Builds a body-less `GET` request for `uri`.
    fn get(uri: &str) -> Request<Body> {
        Request::get(uri)
            .body(Body::empty())
            .unwrap_or_else(|e| panic!("cannot build `Request`: {e}"))
    }

    /// Builds a `POST /` request with the given `Content-Type` header and body.
    fn post(content_type: &str, body: &'static str) -> Request<Body> {
        Request::post("/")
            .header("content-type", content_type)
            .body(Body::from(body))
            .unwrap_or_else(|e| panic!("cannot build `Request`: {e}"))
    }

    /// Runs the extractor, panicking with the rejection's status and body on failure.
    async fn do_from_request(req: Request<Body>) -> JuniperRequest {
        match JuniperRequest::from_request(req, &()).await {
            Ok(extracted) => extracted,
            Err(rejection) => {
                let status = rejection.status();
                let body = display_body(rejection.into_body()).await;
                panic!(
                    "`JuniperRequest::from_request()` failed with `{}` status and body:\n{}",
                    status, body,
                )
            }
        }
    }

    /// Collects the whole response body into a `String`, panicking if it is not UTF-8.
    async fn display_body(body: Body) -> String {
        let bytes: Vec<u8> = body
            .into_data_stream()
            .map_ok(|chunk| chunk.to_vec())
            .try_concat()
            .await
            .unwrap();
        String::from_utf8(bytes).unwrap_or_else(|e| panic!("not UTF-8 body: {e}"))
    }
}