use std::ops::ControlFlow;
use std::sync::Arc;
use http::HeaderMap;
use http::StatusCode;
use http::header;
use schemars::JsonSchema;
use serde::Deserialize;
use tower::BoxError;
use tower::ServiceBuilder;
use tower::ServiceExt;
use crate::layers::ServiceBuilderExt;
use crate::plugin::Plugin;
use crate::plugin::PluginInit;
use crate::register_plugin;
use crate::services::router;
/// Configuration for the CSRF prevention plugin.
///
/// Unknown keys are rejected, and any key that is omitted takes its value from
/// `CSRFConfig::default()`.
#[derive(Deserialize, Debug, Clone, JsonSchema)]
#[serde(deny_unknown_fields)]
#[serde(default)]
pub(crate) struct CSRFConfig {
    /// Set to `true` to skip CSRF checks entirely (defaults to `false`).
    /// Disabling the checks is unsafe.
    unsafe_disabled: bool,
    /// Header names, any one of which lets a request through even when its
    /// `content-type` does not require a CORS preflight. Defaults to
    /// `x-apollo-operation-name` and `apollo-require-preflight`.
    required_headers: Arc<Vec<String>>,
}
/// Default value of `CSRFConfig::required_headers`: the presence of either of these
/// header names lets a request through the CSRF check.
fn apollo_custom_preflight_headers() -> Arc<Vec<String>> {
    const DEFAULT_HEADER_NAMES: [&str; 2] = ["x-apollo-operation-name", "apollo-require-preflight"];
    Arc::new(
        DEFAULT_HEADER_NAMES
            .iter()
            .map(|name| (*name).to_owned())
            .collect(),
    )
}
impl Default for CSRFConfig {
    /// CSRF checks enabled, accepting the default Apollo preflight headers.
    fn default() -> Self {
        let required_headers = apollo_custom_preflight_headers();
        Self {
            required_headers,
            unsafe_disabled: false,
        }
    }
}
/// MIME essences (`type/subtype`, parameters excluded) that a browser may send
/// cross-origin without a CORS preflight, i.e. the Fetch standard's CORS-safelisted
/// `Content-Type` values. A request whose `content-type` matches one of these is not
/// considered preflighted on the strength of its content type alone.
static NON_PREFLIGHTED_CONTENT_TYPES: &[&str] = &[
    "application/x-www-form-urlencoded",
    "multipart/form-data",
    "text/plain",
];
/// Router plugin that rejects requests which are neither preflighted by their
/// `content-type` nor carry one of the configured `required_headers`.
#[derive(Debug, Clone)]
pub(crate) struct Csrf {
    /// Plugin configuration, taken from `PluginInit::config` in `new`.
    config: CSRFConfig,
}
#[async_trait::async_trait]
impl Plugin for Csrf {
type Config = CSRFConfig;
async fn new(init: PluginInit<Self::Config>) -> Result<Self, BoxError> {
Ok(Csrf {
config: init.config,
})
}
fn router_service(&self, service: router::BoxService) -> router::BoxService {
if !self.config.unsafe_disabled {
let required_headers = self.config.required_headers.clone();
ServiceBuilder::new()
.checkpoint(move |req: router::Request| {
if is_preflighted(&req, required_headers.as_slice()) {
tracing::trace!("request is preflighted");
Ok(ControlFlow::Continue(req))
} else {
tracing::trace!("request is not preflighted");
let error = crate::error::Error::builder().message(
format!(
"This operation has been blocked as a potential Cross-Site Request Forgery (CSRF). \
Please either specify a 'content-type' header (with a mime-type that is not one of {}) \
or provide one of the following headers: {}",
NON_PREFLIGHTED_CONTENT_TYPES.join(", "),
required_headers.join(", ")
))
.extension_code("CSRF_ERROR")
.build();
let res = router::Response::infallible_builder()
.error(error)
.status_code(StatusCode::BAD_REQUEST)
.context(req.context)
.build();
Ok(ControlFlow::Break(res))
}
})
.service(service)
.boxed()
} else {
service
}
}
}
/// Returns `true` when the request's `content-type` forces a CORS preflight, or, failing
/// that, when it carries at least one of `required_headers`.
fn is_preflighted(req: &router::Request, required_headers: &[String]) -> bool {
    let request_headers = req.router_request.headers();
    if content_type_requires_preflight(request_headers) {
        return true;
    }
    recommended_header_is_provided(request_headers, required_headers)
}
/// Returns `true` when the request's `content-type` is one that a browser could not send
/// cross-origin without a CORS preflight.
///
/// Every `content-type` value is trimmed and has its tab characters replaced by spaces.
/// The values are then combined with `", "` (the way the Fetch standard combines repeated
/// header values) and parsed as a single MIME type. The result is `true` only if that
/// parse succeeds and its essence (`type/subtype`, without parameters) is not listed in
/// `NON_PREFLIGHTED_CONTENT_TYPES`.
///
/// Errs on the side of caution: a missing header, a value that `HeaderValue::to_str`
/// rejects, or a combined value that does not parse as a MIME type all yield `false`.
fn content_type_requires_preflight(headers: &HeaderMap) -> bool {
    let mut normalized_values = Vec::new();
    for raw_value in headers.get_all(header::CONTENT_TYPE).iter() {
        match raw_value.to_str() {
            // NOTE(review): tabs are presumably swapped for spaces because the `mime`
            // parser does not accept them as whitespace — confirm against the crate.
            Ok(text) => normalized_values.push(text.trim().replace('\u{0009}', "\u{0020}")),
            Err(_) => return false,
        }
    }
    let combined = normalized_values.join("\u{002C}\u{0020}");

    combined
        .parse::<mime::Mime>()
        .map(|parsed| !NON_PREFLIGHTED_CONTENT_TYPES.contains(&parsed.essence_str()))
        .unwrap_or(false)
}
/// Returns `true` if `headers` contains at least one of the `required_headers` names.
/// An empty `required_headers` list never matches.
fn recommended_header_is_provided(headers: &HeaderMap, required_headers: &[String]) -> bool {
    for name in required_headers {
        if headers.contains_key(name) {
            return true;
        }
    }
    false
}
// Registers `Csrf` with the plugin registry under the name `apollo.csrf`.
register_plugin!("apollo", "csrf", Csrf);
#[cfg(test)]
mod csrf_tests {
    use http::header::CONTENT_TYPE;
    use http_body_util::BodyExt;
    use mime::APPLICATION_JSON;

    use super::*;
    use crate::graphql;
    use crate::plugins::test::PluginTestHarness;

    #[tokio::test]
    async fn plugin_registered() {
        crate::plugin::plugins()
            .find(|factory| factory.name == "apollo.csrf")
            .expect("Plugin not found")
            .create_instance_without_schema(&serde_json::json!({ "unsafe_disabled": true }))
            .await
            .unwrap();
        crate::plugin::plugins()
            .find(|factory| factory.name == "apollo.csrf")
            .expect("Plugin not found")
            .create_instance_without_schema(&serde_json::json!({}))
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn it_lets_preflighted_request_pass_through() {
        let with_preflight_content_type = router::Request::fake_builder()
            .header(CONTENT_TYPE, APPLICATION_JSON.essence_str())
            .build()
            .unwrap();
        assert_accepted(
            include_str!("fixtures/default.router.yaml"),
            with_preflight_content_type,
        )
        .await;

        let with_preflight_header = router::Request::fake_builder()
            .header("apollo-require-preflight", "this-is-a-test")
            .build()
            .unwrap();
        assert_accepted(
            include_str!("fixtures/default.router.yaml"),
            with_preflight_header,
        )
        .await;
    }

    #[tokio::test]
    async fn it_rejects_preflighted_multipart_form_data() {
        let with_preflight_content_type = router::Request::fake_builder()
            .header(CONTENT_TYPE, "multipart/form-data; boundary=842705fe5c26bcc3-e1302903b7efd762-d3aeccc8154e83c9-2ac7e6d91c6a7fdc")
            .build()
            .unwrap();
        assert_rejected(
            include_str!("fixtures/default.router.yaml"),
            with_preflight_content_type,
        )
        .await;
    }

    #[tokio::test]
    async fn it_rejects_non_preflighted_headers_request() {
        let non_preflighted_request =
            without_content_type(router::Request::fake_builder().build().unwrap());
        assert_rejected(
            include_str!("fixtures/default.router.yaml"),
            non_preflighted_request,
        )
        .await
    }

    #[tokio::test]
    async fn it_rejects_non_preflighted_content_type_request() {
        let non_preflighted_request = router::Request::fake_builder()
            .header(CONTENT_TYPE, "text/plain")
            .build()
            .unwrap();
        assert_rejected(
            include_str!("fixtures/default.router.yaml"),
            non_preflighted_request,
        )
        .await;

        let non_preflighted_request = router::Request::fake_builder()
            .header(CONTENT_TYPE, "text/plain; charset=utf8")
            .build()
            .unwrap();
        assert_rejected(
            include_str!("fixtures/default.router.yaml"),
            non_preflighted_request,
        )
        .await;
    }

    #[tokio::test]
    async fn it_accepts_non_preflighted_headers_request_when_plugin_is_disabled() {
        // Without a content-type the request would be rejected if the plugin were active,
        // so acceptance here really exercises `unsafe_disabled`.
        let non_preflighted_request =
            without_content_type(router::Request::fake_builder().build().unwrap());
        assert_accepted(
            include_str!("fixtures/unsafe_disabled.router.yaml"),
            non_preflighted_request,
        )
        .await
    }

    #[tokio::test]
    async fn it_rejects_non_preflighted_headers_request_when_required_headers_are_not_present() {
        let non_preflighted_request =
            without_content_type(router::Request::fake_builder().build().unwrap());
        assert_rejected(
            include_str!("fixtures/required_headers.router.yaml"),
            non_preflighted_request,
        )
        .await
    }

    #[tokio::test]
    async fn it_accepts_non_preflighted_headers_request_when_required_headers_are_present() {
        // Strip content-type so only the custom header can make this request pass.
        let non_preflighted_request = without_content_type(
            router::Request::fake_builder()
                .header("X-MY-CSRF-Token", "this-is-a-test")
                .build()
                .unwrap(),
        );
        assert_accepted(
            include_str!("fixtures/required_headers.router.yaml"),
            non_preflighted_request,
        )
        .await
    }

    /// Removes any `content-type` header from `request`.
    ///
    /// `router::Request::fake_builder` may populate `content-type` by default; stripping
    /// it ensures a request can only count as preflighted through the headers under test,
    /// instead of passing vacuously on a JSON content type.
    fn without_content_type(mut request: router::Request) -> router::Request {
        request.router_request.headers_mut().remove(CONTENT_TYPE);
        request
    }

    /// Runs `request` through a `Csrf` plugin built from the YAML `config`, with a fake
    /// downstream service, and decodes the resulting GraphQL response body.
    async fn call_csrf_plugin(config: &'static str, request: router::Request) -> graphql::Response {
        let plugin = PluginTestHarness::<Csrf>::builder()
            .config(config)
            .build()
            .await
            .expect("test harness");
        let router_service =
            plugin.router_service(|_r| async { router::Response::fake_builder().build() });
        let mut resp = router_service
            .call(request)
            .await
            .expect("expected response");
        let body = resp
            .response
            .body_mut()
            .collect()
            .await
            .expect("expected body");
        serde_json::from_slice(&body.to_bytes()).unwrap()
    }

    /// Asserts the request reached the downstream service (no GraphQL errors returned).
    async fn assert_accepted(config: &'static str, request: router::Request) {
        let response = call_csrf_plugin(config, request).await;
        assert_eq!(response.errors.len(), 0);
    }

    /// Asserts the request was blocked with exactly one `CSRF_ERROR` GraphQL error.
    async fn assert_rejected(config: &'static str, request: router::Request) {
        let response = call_csrf_plugin(config, request).await;
        assert_eq!(response.errors.len(), 1);
        assert_eq!(
            response.errors[0]
                .extensions
                .get("code")
                .expect("error code")
                .as_str(),
            Some("CSRF_ERROR")
        );
    }
}