#![cfg_attr(any(doc, test), doc = include_str!("../README.md"))]
#![cfg_attr(not(any(doc, test)), doc = env!("CARGO_PKG_NAME"))]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![cfg_attr(
test,
expect(unused_crate_dependencies, reason = "examples and integration tests")
)]
// NOTE(review): judging by its name, this module exists solely so that the
// referenced dependencies count as "used" when the crate is built for a
// minimal-versions check — confirm against the CI setup.
// The `as _` imports bring nothing into scope; they only reference the crates.
mod for_minimal_versions_check_only {
use http_body_util as _;
// Referenced in test builds only.
#[cfg(test)]
use hyper_util as _;
}
mod response;
#[cfg(feature = "subscriptions")]
pub mod subscriptions;
use std::{collections::HashMap, str, sync::Arc};
use derive_more::with_trait::Display;
use juniper::{
ScalarValue,
http::{GraphQLBatchRequest, GraphQLRequest},
};
use tokio::task;
use warp::{
Filter,
body::{self, BodyDeserializeError},
http::{self, StatusCode},
hyper::body::Bytes,
query,
reject::{self, Reject, Rejection},
reply::{self, Reply},
};
use self::response::JuniperResponse;
/// Builds a [`Filter`] that executes GraphQL requests against the provided
/// `schema` asynchronously.
///
/// A request is accepted in one of three forms, tried in this order:
/// 1. `GET` with the operation in the URL query string;
/// 2. `POST` with a JSON body (single or batch request);
/// 3. `POST` with the raw body used verbatim as the GraphQL query.
///
/// The `context_extractor` is evaluated after a request has been extracted
/// and supplies the context value passed to the schema execution.
///
/// Rejections recognized by `handle_rejects` are turned into
/// `400 Bad Request` responses; any other rejection is propagated unchanged.
pub fn make_graphql_filter<S, Query, Mutation, Subscription, CtxT, CtxErr>(
    schema: impl Into<Arc<juniper::RootNode<Query, Mutation, Subscription, S>>>,
    context_extractor: impl Filter<Extract = (CtxT,), Error = CtxErr> + Send + Sync + 'static,
) -> impl Filter<Extract = (reply::Response,), Error = Rejection> + Clone + Send
where
    Query: juniper::GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
    Query::TypeInfo: Send + Sync,
    Mutation: juniper::GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
    Mutation::TypeInfo: Send + Sync,
    Subscription: juniper::GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
    Subscription::TypeInfo: Send + Sync,
    CtxT: Send + Sync + 'static,
    CtxErr: Into<Rejection>,
    S: ScalarValue + Send + Sync + 'static,
{
    // A single shared schema; every request receives its own `Arc` handle.
    let shared_schema: Arc<juniper::RootNode<Query, Mutation, Subscription, S>> = schema.into();
    let with_schema = warp::any().map(move || Arc::clone(&shared_schema));
    let with_context = context_extractor.boxed();

    // All supported transport encodings collapsed into one request type.
    let request = get_query_extractor::<S>()
        .or(post_json_extractor::<S>())
        .unify()
        .or(post_graphql_extractor::<S>())
        .unify();

    request
        .and(with_schema)
        .and(with_context)
        .then(graphql_handler::<Query, Mutation, Subscription, CtxT, S>)
        .recover(handle_rejects)
        .unify()
}
/// Builds a [`Filter`] that executes GraphQL requests against the provided
/// `schema` synchronously.
///
/// Requests are extracted exactly as in the asynchronous variant (`GET`
/// query string, `POST` JSON body, or `POST` raw query body, in this order),
/// but the execution itself is handed to `graphql_handler_sync`, which runs
/// it on Tokio's blocking thread pool.
///
/// Rejections recognized by `handle_rejects` are turned into
/// `400 Bad Request` responses; any other rejection is propagated unchanged.
pub fn make_graphql_filter_sync<S, Query, Mutation, Subscription, CtxT, CtxErr>(
    schema: impl Into<Arc<juniper::RootNode<Query, Mutation, Subscription, S>>>,
    context_extractor: impl Filter<Extract = (CtxT,), Error = CtxErr> + Send + Sync + 'static,
) -> impl Filter<Extract = (reply::Response,), Error = Rejection> + Clone + Send
where
    Query: juniper::GraphQLType<S, Context = CtxT> + Send + Sync + 'static,
    Query::TypeInfo: Send + Sync,
    Mutation: juniper::GraphQLType<S, Context = CtxT> + Send + Sync + 'static,
    Mutation::TypeInfo: Send + Sync,
    Subscription: juniper::GraphQLType<S, Context = CtxT> + Send + Sync + 'static,
    Subscription::TypeInfo: Send + Sync,
    CtxT: Send + Sync + 'static,
    CtxErr: Into<Rejection>,
    S: ScalarValue + Send + Sync + 'static,
{
    // A single shared schema; every request receives its own `Arc` handle.
    let shared_schema: Arc<juniper::RootNode<Query, Mutation, Subscription, S>> = schema.into();
    let with_schema = warp::any().map(move || Arc::clone(&shared_schema));
    let with_context = context_extractor.boxed();

    // All supported transport encodings collapsed into one request type.
    let request = get_query_extractor::<S>()
        .or(post_json_extractor::<S>())
        .unify()
        .or(post_graphql_extractor::<S>())
        .unify();

    request
        .and(with_schema)
        .and(with_context)
        .then(graphql_handler_sync::<Query, Mutation, Subscription, CtxT, S>)
        .recover(handle_rejects)
        .unify()
}
/// Executes the extracted `req` against `schema` asynchronously, using
/// `context`, and converts the outcome into an HTTP response via
/// [`JuniperResponse`].
async fn graphql_handler<Query, Mutation, Subscription, CtxT, S>(
    req: GraphQLBatchRequest<S>,
    schema: Arc<juniper::RootNode<Query, Mutation, Subscription, S>>,
    context: CtxT,
) -> reply::Response
where
    Query: juniper::GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
    Query::TypeInfo: Send + Sync,
    Mutation: juniper::GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
    Mutation::TypeInfo: Send + Sync,
    Subscription: juniper::GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
    Subscription::TypeInfo: Send + Sync,
    CtxT: Send + Sync + 'static,
    S: ScalarValue + Send + Sync + 'static,
{
    // Borrow the root node out of the shared handle for the execution.
    let root: &juniper::RootNode<Query, Mutation, Subscription, S> = &schema;
    JuniperResponse(req.execute(root, &context).await).into_response()
}
/// Executes the extracted `req` against `schema` synchronously on Tokio's
/// blocking thread pool and converts the outcome into an HTTP response.
///
/// If the blocking task cannot be joined, the failure is reported through
/// [`BlockingError`] as a `500 Internal Server Error` response.
async fn graphql_handler_sync<Query, Mutation, Subscription, CtxT, S>(
    req: GraphQLBatchRequest<S>,
    schema: Arc<juniper::RootNode<Query, Mutation, Subscription, S>>,
    context: CtxT,
) -> reply::Response
where
    Query: juniper::GraphQLType<S, Context = CtxT> + Send + Sync + 'static,
    Query::TypeInfo: Send + Sync,
    Mutation: juniper::GraphQLType<S, Context = CtxT> + Send + Sync + 'static,
    Mutation::TypeInfo: Send + Sync,
    Subscription: juniper::GraphQLType<S, Context = CtxT> + Send + Sync + 'static,
    Subscription::TypeInfo: Send + Sync,
    CtxT: Send + Sync + 'static,
    S: ScalarValue + Send + Sync + 'static,
{
    // Keep the synchronous execution off the async worker threads.
    let joined = task::spawn_blocking(move || req.execute_sync(&*schema, &context)).await;

    match joined {
        Ok(resp) => JuniperResponse(resp).into_response(),
        Err(join_err) => BlockingError(join_err).into_response(),
    }
}
/// Filter extracting a [`GraphQLBatchRequest`] from the JSON body of a
/// `POST` HTTP request.
fn post_json_extractor<S>()
-> impl Filter<Extract = (GraphQLBatchRequest<S>,), Error = Rejection> + Clone + Send
where
    S: ScalarValue + Send,
{
    let json_body = body::json::<GraphQLBatchRequest<S>>();
    warp::post().and(json_body)
}
/// Filter extracting a [`GraphQLBatchRequest`] from a `POST` HTTP request
/// whose whole body is the GraphQL query text.
///
/// The body must be valid UTF-8, otherwise the request is rejected with
/// [`FilterError::NonUtf8Body`]. No operation name or variables are set.
fn post_graphql_extractor<S>()
-> impl Filter<Extract = (GraphQLBatchRequest<S>,), Error = Rejection> + Clone + Send
where
    S: ScalarValue + Send,
{
    warp::post()
        .and(body::bytes())
        .and_then(async |raw: Bytes| match str::from_utf8(&raw) {
            Ok(query) => Ok::<GraphQLBatchRequest<S>, Rejection>(GraphQLBatchRequest::Single(
                GraphQLRequest::new(query.to_owned(), None, None),
            )),
            Err(e) => Err(reject::custom(FilterError::NonUtf8Body(e))),
        })
}
/// Filter extracting a [`GraphQLBatchRequest`] from the URL query string of
/// a `GET` HTTP request.
///
/// Recognized query parameters:
/// - `query` (required): the GraphQL document; when absent the request is
///   rejected with [`FilterError::MissingPathQuery`];
/// - `operationName` (optional): the operation to execute. The legacy
///   `operation_name` spelling is still accepted as a fallback, so that
///   existing clients keep working;
/// - `variables` (optional): a JSON-encoded value; when it cannot be
///   deserialized the request is rejected with
///   [`FilterError::InvalidPathVariables`].
fn get_query_extractor<S>()
-> impl Filter<Extract = (GraphQLBatchRequest<S>,), Error = Rejection> + Clone + Send
where
    S: ScalarValue + Send,
{
    warp::get()
        .and(query::query())
        .and_then(async |mut qry: HashMap<String, String>| {
            let query = qry
                .remove("query")
                .ok_or_else(|| reject::custom(FilterError::MissingPathQuery))?;
            // GraphQL-over-HTTP clients send `operationName`; the snake_case
            // key is what this filter historically read, so it is consulted
            // only when the standard one is absent.
            let operation_name = qry
                .remove("operationName")
                .or_else(|| qry.remove("operation_name"));
            let req = GraphQLRequest::new(
                query,
                operation_name,
                qry.remove("variables")
                    .map(|vs| serde_json::from_str(&vs))
                    .transpose()
                    .map_err(|e| reject::custom(FilterError::InvalidPathVariables(e)))?,
            );
            Ok::<GraphQLBatchRequest<S>, Rejection>(GraphQLBatchRequest::Single(req))
        })
}
/// Converts the rejections this crate knows about into `400 Bad Request`
/// responses whose body is the rejection's message.
///
/// Handled causes are [`FilterError`], `warp::reject::InvalidQuery` and
/// [`BodyDeserializeError`], checked in this order. Any other rejection is
/// handed back untouched as `Err`.
async fn handle_rejects(rej: Rejection) -> Result<reply::Response, Rejection> {
    let known_cause = rej
        .find::<FilterError>()
        .map(|e| e.to_string())
        .or_else(|| {
            rej.find::<warp::reject::InvalidQuery>()
                .map(|e| e.to_string())
        })
        .or_else(|| rej.find::<BodyDeserializeError>().map(|e| e.to_string()));

    let Some(msg) = known_cause else {
        return Err(rej);
    };

    Ok(http::Response::builder()
        .status(StatusCode::BAD_REQUEST)
        .body(msg.into())
        .unwrap())
}
/// Rejections raised by this crate's own request-extraction filters.
///
/// `handle_rejects` answers each of them with a `400 Bad Request` response
/// carrying the variant's display message as the body.
#[derive(Debug, Display)]
enum FilterError {
/// A `GET` request has no `query` parameter in its query string.
#[display("Missing GraphQL `query` string in query parameters")]
MissingPathQuery,
/// The `variables` query-string parameter of a `GET` request is not valid
/// JSON for GraphQL variables.
#[display("Failed to deserialize GraphQL `variables` from JSON: {_0}")]
InvalidPathVariables(serde_json::Error),
/// The body of a `POST` request carrying a raw GraphQL query is not valid
/// UTF-8.
#[display("Request body is not a valid UTF-8 string: {_0}")]
NonUtf8Body(str::Utf8Error),
}
// Marker impl allowing `FilterError` to be wrapped via `reject::custom()`.
impl Reject for FilterError {}
/// Failure to join the blocking task that executes a synchronous GraphQL
/// request.
#[derive(Debug)]
struct BlockingError(task::JoinError);

impl Reply for BlockingError {
    /// Renders the error as a `500 Internal Server Error` response whose
    /// body describes the join failure.
    fn into_response(self) -> reply::Response {
        let Self(join_err) = self;
        let msg = format!("Failed to execute synchronous GraphQL request: {join_err}");

        match http::Response::builder()
            .status(StatusCode::INTERNAL_SERVER_ERROR)
            .body(msg.into())
        {
            Ok(resp) => resp,
            Err(e) => {
                unreachable!("cannot build `reply::Response` out of `BlockingError`: {e}")
            }
        }
    }
}
/// Creates a boxed filter replying with the GraphiQL HTML page.
///
/// `graphql_endpoint_url` and the optional `subscriptions_endpoint` are
/// forwarded to `graphiql_response`, which renders the page on every
/// matching request.
pub fn graphiql_filter(
    graphql_endpoint_url: &'static str,
    subscriptions_endpoint: Option<&'static str>,
) -> warp::filters::BoxedFilter<(http::Response<Vec<u8>>,)> {
    let render = move || graphiql_response(graphql_endpoint_url, subscriptions_endpoint);
    warp::any().map(render).boxed()
}
/// Builds the HTTP response carrying the GraphiQL page source, served as
/// `text/html;charset=utf-8`.
fn graphiql_response(
    graphql_endpoint_url: &'static str,
    subscriptions_endpoint: Option<&'static str>,
) -> http::Response<Vec<u8>> {
    let html =
        juniper::http::graphiql::graphiql_source(graphql_endpoint_url, subscriptions_endpoint);

    http::Response::builder()
        .header("content-type", "text/html;charset=utf-8")
        .body(html.into_bytes())
        .expect("response is valid")
}
/// Creates a boxed filter replying with the GraphQL Playground HTML page.
///
/// `graphql_endpoint_url` and the optional `subscriptions_endpoint_url` are
/// forwarded to `playground_response`, which renders the page on every
/// matching request.
pub fn playground_filter(
    graphql_endpoint_url: &'static str,
    subscriptions_endpoint_url: Option<&'static str>,
) -> warp::filters::BoxedFilter<(http::Response<Vec<u8>>,)> {
    let render = move || playground_response(graphql_endpoint_url, subscriptions_endpoint_url);
    warp::any().map(render).boxed()
}
/// Builds the HTTP response carrying the GraphQL Playground page source,
/// served as `text/html;charset=utf-8`.
fn playground_response(
    graphql_endpoint_url: &'static str,
    subscriptions_endpoint_url: Option<&'static str>,
) -> http::Response<Vec<u8>> {
    let html = juniper::http::playground::playground_source(
        graphql_endpoint_url,
        subscriptions_endpoint_url,
    );

    http::Response::builder()
        .header("content-type", "text/html;charset=utf-8")
        .body(html.into_bytes())
        .expect("response is valid")
}
#[cfg(test)]
mod tests {
mod make_graphql_filter {
    use std::future;

    use juniper::{
        EmptyMutation, EmptySubscription,
        http::GraphQLBatchRequest,
        tests::fixtures::starwars::schema::{Database, Query},
    };
    use warp::{
        Filter as _, Reply, http,
        reject::{self, Reject},
        test::request,
    };

    use super::super::make_graphql_filter;

    /// Query-only schema over the Star Wars fixture database.
    type Schema = juniper::RootNode<Query, EmptyMutation<Database>, EmptySubscription<Database>>;

    /// JSON request body asking for the name of the `NEW_HOPE` hero.
    const HERO_NAME_REQUEST: &str =
        r#"{"variables": null, "query": "{ hero(episode: NEW_HOPE) { name } }"}"#;

    /// Creates a fresh fixture [`Schema`].
    fn starwars_schema() -> Schema {
        Schema::new(Query, EmptyMutation::new(), EmptySubscription::new())
    }

    /// A single JSON `POST` request is executed and answered as JSON.
    #[tokio::test]
    async fn post_json() {
        let context = warp::any().map(Database::new);
        let filter = warp::path("graphql2").and(make_graphql_filter(starwars_schema(), context));

        let resp = request()
            .method("POST")
            .path("/graphql2")
            .header("accept", "application/json")
            .header("content-type", "application/json")
            .body(HERO_NAME_REQUEST)
            .reply(&filter)
            .await;

        assert_eq!(resp.status(), http::StatusCode::OK);
        assert_eq!(
            resp.headers().get("content-type").unwrap(),
            "application/json",
        );
        assert_eq!(
            std::str::from_utf8(resp.body()).unwrap(),
            r#"{"data":{"hero":{"name":"R2-D2"}}}"#,
        );
    }

    /// The context extractor fails only on its first invocation, so the
    /// teapot status is observed only if the extractor is not invoked a
    /// second time for the same request.
    #[tokio::test]
    async fn rejects_fast_when_context_extractor_fails() {
        use std::sync::{
            Arc,
            atomic::{AtomicBool, Ordering},
        };

        #[derive(Clone, Copy, Debug)]
        struct ExtractionError;

        impl Reject for ExtractionError {}

        impl warp::Reply for ExtractionError {
            fn into_response(self) -> warp::reply::Response {
                http::StatusCode::IM_A_TEAPOT.into_response()
            }
        }

        let is_called = Arc::new(AtomicBool::new(false));
        let context_extractor = warp::any().and_then(move || {
            let called_before = is_called.swap(true, Ordering::Relaxed);
            future::ready(if called_before {
                Ok(Database::new())
            } else {
                Err(reject::custom(ExtractionError))
            })
        });

        let filter = warp::path("graphql")
            .and(make_graphql_filter(starwars_schema(), context_extractor))
            .recover(async |rejection: warp::reject::Rejection| {
                match rejection.find::<ExtractionError>() {
                    Some(e) => Ok(e.into_response()),
                    None => Err(rejection),
                }
            });

        let resp = request()
            .method("POST")
            .path("/graphql")
            .header("accept", "application/json")
            .header("content-type", "application/json")
            .body(HERO_NAME_REQUEST)
            .reply(&filter)
            .await;

        assert_eq!(
            resp.status(),
            http::StatusCode::IM_A_TEAPOT,
            "response: {resp:#?}",
        );
    }

    /// A JSON array of requests is executed as a batch and answered with a
    /// JSON array of results.
    #[tokio::test]
    async fn batch_requests() {
        let context = warp::any().map(Database::new);
        let filter = warp::path("graphql2").and(make_graphql_filter(starwars_schema(), context));

        // The raw literal is intentionally kept flush-left so that the bytes
        // sent to the filter stay exactly as they were.
        let resp = request()
            .method("POST")
            .path("/graphql2")
            .header("accept", "application/json")
            .header("content-type", "application/json")
            .body(
                r#"[
{"variables": null, "query": "{ hero(episode: NEW_HOPE) { name } }"},
{"variables": null, "query": "{ hero(episode: EMPIRE) { id name } }"}
]"#,
            )
            .reply(&filter)
            .await;

        assert_eq!(resp.status(), http::StatusCode::OK);
        assert_eq!(
            std::str::from_utf8(resp.body()).unwrap(),
            r#"[{"data":{"hero":{"name":"R2-D2"}}},{"data":{"hero":{"id":"1000","name":"Luke Skywalker"}}}]"#,
        );
        assert_eq!(
            resp.headers().get("content-type").unwrap(),
            "application/json",
        );
    }

    /// Input that is not JSON does not deserialize into a batch request.
    #[test]
    fn batch_request_deserialization_can_fail() {
        let parsed: Result<GraphQLBatchRequest, _> = serde_json::from_str(r#"blah"#);

        assert!(parsed.is_err());
    }
}
mod graphiql_filter {
    use warp::{Filter as _, http, test::request};

    use super::super::{graphiql_filter, graphiql_response};

    /// Building the GraphiQL response must not hit its internal `expect()`.
    #[test]
    fn response_does_not_panic() {
        graphiql_response("/abcd", None);
    }

    /// The filter matches a `GET` request on the mounted path.
    #[tokio::test]
    async fn endpoint_matches() {
        let filter = warp::get()
            .and(warp::path("graphiql"))
            .and(graphiql_filter("/graphql", None));

        let result = request()
            .method("GET")
            .path("/graphiql")
            .header("accept", "text/html")
            .filter(&filter)
            .await;

        assert!(result.is_ok());
    }

    /// The served page is HTML and embeds the configured GraphQL endpoint.
    #[tokio::test]
    async fn returns_graphiql_source() {
        let filter = warp::get()
            .and(warp::path("dogs-api"))
            .and(warp::path("graphiql"))
            .and(graphiql_filter("/dogs-api/graphql", None));

        let response = request()
            .method("GET")
            .path("/dogs-api/graphiql")
            .header("accept", "text/html")
            .reply(&filter)
            .await;

        assert_eq!(response.status(), http::StatusCode::OK);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "text/html;charset=utf-8"
        );
        let body = String::from_utf8(response.body().to_vec()).unwrap();
        assert!(body.contains("const JUNIPER_URL = '/dogs-api/graphql';"));
    }

    /// The filter also matches when a subscriptions endpoint is configured.
    #[tokio::test]
    async fn endpoint_with_subscription_matches() {
        // The fixture used to read `ws:://…` (doubled colon), which is not a
        // well-formed WebSocket URL.
        let filter = warp::get().and(warp::path("graphiql")).and(graphiql_filter(
            "/graphql",
            Some("ws://localhost:8080/subscriptions"),
        ));

        let result = request()
            .method("GET")
            .path("/graphiql")
            .header("accept", "text/html")
            .filter(&filter)
            .await;

        assert!(result.is_ok());
    }
}
mod playground_filter {
    use warp::{Filter as _, http, test::request};

    use super::super::playground_filter;

    /// The filter matches a `GET` request on the mounted path.
    #[tokio::test]
    async fn endpoint_matches() {
        // The fixture path used to be misspelled as `/subscripitons`.
        let filter = warp::get()
            .and(warp::path("playground"))
            .and(playground_filter("/graphql", Some("/subscriptions")));

        let result = request()
            .method("GET")
            .path("/playground")
            .header("accept", "text/html")
            .filter(&filter)
            .await;

        assert!(result.is_ok());
    }

    /// The served page is HTML and embeds both configured endpoints.
    #[tokio::test]
    async fn returns_playground_source() {
        let filter = warp::get()
            .and(warp::path("dogs-api"))
            .and(warp::path("playground"))
            .and(playground_filter(
                "/dogs-api/graphql",
                Some("/dogs-api/subscriptions"),
            ));

        let response = request()
            .method("GET")
            .path("/dogs-api/playground")
            .header("accept", "text/html")
            .reply(&filter)
            .await;

        assert_eq!(response.status(), http::StatusCode::OK);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "text/html;charset=utf-8"
        );
        let body = String::from_utf8(response.body().to_vec()).unwrap();
        assert!(body.contains(
            "endpoint: '/dogs-api/graphql', subscriptionEndpoint: '/dogs-api/subscriptions'",
        ));
    }
}
}