use std::ops::ControlFlow;
use http::HeaderMap;
use http::Method;
use http::StatusCode;
use http::header::ACCEPT;
use http::header::CONTENT_TYPE;
use mediatype::MediaTypeList;
use mediatype::ReadParams;
use mediatype::names::_STAR;
use mediatype::names::APPLICATION;
use mediatype::names::JSON;
use mediatype::names::MIXED;
use mediatype::names::MULTIPART;
use mime::APPLICATION_JSON;
use tower::BoxError;
use tower::Layer;
use tower::Service;
use tower::ServiceExt;
use crate::graphql;
use crate::layers::ServiceExt as _;
use crate::layers::sync_checkpoint::CheckpointService;
use crate::services::APPLICATION_JSON_HEADER_VALUE;
use crate::services::MULTIPART_DEFER_ACCEPT;
use crate::services::MULTIPART_DEFER_SPEC_PARAMETER;
use crate::services::MULTIPART_DEFER_SPEC_VALUE;
use crate::services::MULTIPART_SUBSCRIPTION_ACCEPT;
use crate::services::MULTIPART_SUBSCRIPTION_SPEC_PARAMETER;
use crate::services::MULTIPART_SUBSCRIPTION_SPEC_VALUE;
use crate::services::router;
use crate::services::router::ClientRequestAccepts;
use crate::services::router::service::MULTIPART_DEFER_CONTENT_TYPE_HEADER_VALUE;
use crate::services::router::service::MULTIPART_SUBSCRIPTION_CONTENT_TYPE_HEADER_VALUE;
use crate::services::supergraph;
/// Media type string `application/graphql-response+json`.
///
/// Interpolated into this module's `content-type` and `accept` error messages as one of the
/// values a client may send.
pub(crate) const GRAPHQL_JSON_RESPONSE_HEADER_VALUE: &str = "application/graphql-response+json";
#[derive(Clone, Default)]
pub(crate) struct RouterLayer {}
impl<S> Layer<S> for RouterLayer
where
S: Service<router::Request, Response = router::Response, Error = BoxError> + Send + 'static,
<S as Service<router::Request>>::Future: Send + 'static,
{
type Service = CheckpointService<S, router::Request>;
fn layer(&self, service: S) -> Self::Service {
CheckpointService::new(
move |req| {
if req.router_request.method() != Method::GET
&& !content_type_is_json(req.router_request.headers())
{
let response: http::Response<crate::services::router::Body> = http::Response::builder()
.status(StatusCode::UNSUPPORTED_MEDIA_TYPE)
.header(CONTENT_TYPE, APPLICATION_JSON.essence_str())
.body(crate::services::router::Body::from(
serde_json::json!({
"errors": [
graphql::Error::builder()
.message(format!(
r#"'content-type' header must be one of: {:?} or {:?}"#,
APPLICATION_JSON.essence_str(),
GRAPHQL_JSON_RESPONSE_HEADER_VALUE,
))
.extension_code("INVALID_CONTENT_TYPE_HEADER")
.build()
]
})
.to_string(),
))
.expect("cannot fail");
u64_counter!(
"apollo_router_http_requests_total",
"Total number of HTTP requests made. (deprecated)",
1,
status = StatusCode::UNSUPPORTED_MEDIA_TYPE.as_u16() as i64,
error = format!(
r#"'content-type' header must be one of: {:?} or {:?}"#,
APPLICATION_JSON.essence_str(),
GRAPHQL_JSON_RESPONSE_HEADER_VALUE,
)
);
return Ok(ControlFlow::Break(response.into()));
}
if req.router_request.method() == Method::GET
&& !content_type_is_strictly_json_or_missing(req.router_request.headers())
{
let response: http::Response<crate::services::router::Body> = http::Response::builder()
.status(StatusCode::UNSUPPORTED_MEDIA_TYPE)
.header(CONTENT_TYPE, APPLICATION_JSON.essence_str())
.body(crate::services::router::Body::from(
serde_json::json!({
"errors": [
graphql::Error::builder()
.message(format!("GET request 'content-type' header may only contain: {:?}", APPLICATION_JSON.essence_str()))
.extension_code("INVALID_CONTENT_TYPE_HEADER")
.build()
]
})
.to_string(),
))
.expect("cannot fail");
u64_counter!(
"apollo_router_http_requests_total",
"Total number of HTTP requests made. (deprecated)",
1,
status = StatusCode::UNSUPPORTED_MEDIA_TYPE.as_u16() as i64,
error = format!(
"GET request 'content-type' header may only contain: {:?}",
APPLICATION_JSON.essence_str()
)
);
return Ok(ControlFlow::Break(response.into()));
}
let accepts = parse_accept(req.router_request.headers());
if accepts.wildcard
|| accepts.multipart_defer
|| accepts.multipart_subscription
|| accepts.json
{
req.context
.extensions()
.with_lock(|mut lock| lock.insert(accepts));
Ok(ControlFlow::Continue(req))
} else {
let response: http::Response<hyper::Body> = http::Response::builder().status(StatusCode::NOT_ACCEPTABLE).header(CONTENT_TYPE, APPLICATION_JSON.essence_str()).body(
hyper::Body::from(
serde_json::json!({
"errors": [
graphql::Error::builder()
.message(format!(
r#"'accept' header must be one of: \"*/*\", {:?}, {:?}, {:?} or {:?}"#,
APPLICATION_JSON.essence_str(),
GRAPHQL_JSON_RESPONSE_HEADER_VALUE,
MULTIPART_SUBSCRIPTION_ACCEPT,
MULTIPART_DEFER_ACCEPT
))
.extension_code("INVALID_ACCEPT_HEADER")
.build()
]
}).to_string())).expect("cannot fail");
Ok(ControlFlow::Break(response.into()))
}
},
service,
)
}
}
/// Tower layer that sets the `content-type` header of the first GraphQL response based on
/// the [`ClientRequestAccepts`] found in the request context.
#[derive(Clone, Default)]
pub(crate) struct SupergraphLayer {}

impl<S> Layer<S> for SupergraphLayer
where
    S: Service<supergraph::Request, Response = supergraph::Response, Error = BoxError>
        + Send
        + 'static,
    <S as Service<supergraph::Request>>::Future: Send + 'static,
{
    type Service = supergraph::BoxService;

    /// Wraps `service` so that the response parts receive a `content-type` header:
    /// * JSON when the response has no follow-up (`has_next` unset or false) and the client
    ///   accepts JSON or `*/*`,
    /// * otherwise the multipart defer value when the client accepts it,
    /// * otherwise the multipart subscription value when the client accepts it,
    /// * otherwise the headers are left as they are.
    fn layer(&self, service: S) -> Self::Service {
        service
            .map_first_graphql_response(|context, mut parts, res| {
                // Falls back to the default flags when nothing was stored in the context.
                let stored = context
                    .extensions()
                    .with_lock(|lock| lock.get::<ClientRequestAccepts>().cloned());
                // Destructured without `..` so a new field forces this code to be revisited.
                let ClientRequestAccepts {
                    wildcard,
                    json,
                    multipart_defer,
                    multipart_subscription,
                } = stored.unwrap_or_default();

                let more_to_come = res.has_next.unwrap_or_default();
                let content_type = if !more_to_come && (json || wildcard) {
                    Some(APPLICATION_JSON_HEADER_VALUE.clone())
                } else if multipart_defer {
                    Some(MULTIPART_DEFER_CONTENT_TYPE_HEADER_VALUE.clone())
                } else if multipart_subscription {
                    Some(MULTIPART_SUBSCRIPTION_CONTENT_TYPE_HEADER_VALUE.clone())
                } else {
                    None
                };

                if let Some(header_value) = content_type {
                    parts.headers.insert(CONTENT_TYPE, header_value);
                }
                (parts, res)
            })
            .boxed()
    }
}
/// Returns `true` when the request has no `content-type` header, or when every media type
/// listed in every `content-type` header is `application/json`.
///
/// Only the type and subtype are compared, so parameters such as `charset` are ignored.
/// A header value that cannot be read as a string, or an entry that fails to parse as a
/// media type, makes the whole check fail.
fn content_type_is_strictly_json_or_missing(headers: &HeaderMap) -> bool {
    headers.get_all(CONTENT_TYPE).iter().all(|raw_value| {
        let Ok(text) = raw_value.to_str() else {
            return false;
        };
        MediaTypeList::new(text).all(|entry| {
            matches!(entry, Ok(mime) if mime.ty == APPLICATION && mime.subty == JSON)
        })
    })
}
/// Returns `true` when at least one media type found in the `content-type` headers is
/// `application/json` or `application/graphql-response+json`.
///
/// Header values that cannot be read as a string and entries that fail to parse as a media
/// type are skipped; a request without any `content-type` header yields `false`.
fn content_type_is_json(headers: &HeaderMap) -> bool {
    for raw_value in headers.get_all(CONTENT_TYPE) {
        let Ok(text) = raw_value.to_str() else {
            continue;
        };
        for entry in MediaTypeList::new(text) {
            let Ok(mime) = entry else {
                continue;
            };
            if mime.ty != APPLICATION {
                continue;
            }
            let is_plain_json = mime.subty == JSON;
            let is_graphql_response_json =
                mime.subty.as_str() == "graphql-response" && mime.suffix == Some(JSON);
            if is_plain_json || is_graphql_response_json {
                return true;
            }
        }
    }
    false
}
/// Parses the `accept` headers of a request into a [`ClientRequestAccepts`].
///
/// Each flag is set as soon as one media type in any `accept` header matches:
/// * `json`: `application/json` or `application/graphql-response+json`,
/// * `wildcard`: `*/*`,
/// * `multipart_defer`: `multipart/mixed` carrying the defer spec parameter and value,
/// * `multipart_subscription`: `multipart/mixed` carrying the subscription spec parameter
///   and value.
///
/// Header values that cannot be read as a string and entries that fail to parse as a media
/// type are ignored. When the request has no `accept` header at all, only `json` is set.
fn parse_accept(headers: &HeaderMap) -> ClientRequestAccepts {
    let mut accepts = ClientRequestAccepts::default();

    // No `accept` header at all: only the JSON flag is set.
    if !headers.contains_key(ACCEPT) {
        accepts.json = true;
        return accepts;
    }

    // The parameter names and values to look for do not depend on the request, so they are
    // parsed once here instead of once per media type.
    let defer_parameter =
        mediatype::Name::new(MULTIPART_DEFER_SPEC_PARAMETER).expect("valid name");
    let defer_value = mediatype::Value::new(MULTIPART_DEFER_SPEC_VALUE).expect("valid value");
    let subscription_parameter =
        mediatype::Name::new(MULTIPART_SUBSCRIPTION_SPEC_PARAMETER).expect("valid name");
    let subscription_value =
        mediatype::Value::new(MULTIPART_SUBSCRIPTION_SPEC_VALUE).expect("valid value");

    // Every successfully parsed media type across all `accept` headers; unreadable header
    // values and malformed entries are dropped.
    let media_types = headers
        .get_all(ACCEPT)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(MediaTypeList::new)
        .flatten();

    for mime in media_types {
        if mime.ty == APPLICATION
            && (mime.subty == JSON
                || (mime.subty.as_str() == "graphql-response" && mime.suffix == Some(JSON)))
        {
            accepts.json = true;
        }
        if mime.ty == _STAR && mime.subty == _STAR {
            accepts.wildcard = true;
        }
        if mime.ty == MULTIPART && mime.subty == MIXED {
            if mime.get_param(defer_parameter) == Some(defer_value) {
                accepts.multipart_defer = true;
            }
            if mime.get_param(subscription_parameter) == Some(subscription_value) {
                accepts.multipart_subscription = true;
            }
        }
    }

    accepts
}
#[cfg(test)]
mod tests {
    use http::HeaderValue;

    use super::*;

    /// Builds a header map holding each entry of `values` under `name`, in order.
    fn headers_with(name: http::header::HeaderName, values: &[&'static str]) -> HeaderMap {
        let mut headers = HeaderMap::new();
        for value in values.iter().copied() {
            headers.append(&name, HeaderValue::from_static(value));
        }
        headers
    }

    #[test]
    fn content_type_is_strictly_json_or_missing_accepts_valid_headers() {
        // No headers at all.
        assert!(content_type_is_strictly_json_or_missing(&HeaderMap::new()));

        // Headers other than `content-type` are irrelevant.
        assert!(content_type_is_strictly_json_or_missing(&headers_with(
            ACCEPT,
            &["*/*"]
        )));

        // Empty value, plain JSON, and JSON with a parameter.
        let accepted = ["", "application/json", "application/json; charset=utf-8"];
        for content_type in accepted {
            let headers = headers_with(CONTENT_TYPE, &[content_type]);
            assert!(content_type_is_strictly_json_or_missing(&headers));
        }
    }

    #[test]
    fn content_type_is_strictly_json_or_missing_rejects_invalid_headers() {
        let rejected = [
            "invalid",
            "text/plain",
            "multipart/form-data",
            "application/graphql",
        ];
        for content_type in rejected {
            let headers = headers_with(CONTENT_TYPE, &[content_type]);
            assert!(!content_type_is_strictly_json_or_missing(&headers));
        }

        // One JSON header does not make up for a second, non-JSON one.
        let mixed = headers_with(CONTENT_TYPE, &["application/json", "text/plain"]);
        assert!(!content_type_is_strictly_json_or_missing(&mixed));
    }

    #[test]
    fn it_checks_accept_header() {
        let parsed = parse_accept(&headers_with(
            ACCEPT,
            &[APPLICATION_JSON.essence_str(), "foo/bar"],
        ));
        assert!(parsed.json);

        let parsed = parse_accept(&headers_with(ACCEPT, &["*/*", "foo/bar"]));
        assert!(parsed.wildcard);

        let parsed = parse_accept(&headers_with(
            ACCEPT,
            &["text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8"],
        ));
        assert!(parsed.wildcard);

        let parsed = parse_accept(&headers_with(
            ACCEPT,
            &[GRAPHQL_JSON_RESPONSE_HEADER_VALUE, "foo/bar"],
        ));
        assert!(parsed.json);

        let parsed = parse_accept(&headers_with(
            ACCEPT,
            &[GRAPHQL_JSON_RESPONSE_HEADER_VALUE, MULTIPART_DEFER_ACCEPT],
        ));
        assert!(parsed.multipart_defer);

        let parsed = parse_accept(&headers_with(
            ACCEPT,
            &["multipart/mixed;subscriptionSpec=1.0, application/json"],
        ));
        assert!(parsed.multipart_subscription);
    }
}