use std::ops::ControlFlow;
use std::sync::Arc;
use http::HeaderMap;
use http::StatusCode;
use http::header;
use schemars::JsonSchema;
use serde::Deserialize;
use tower::BoxError;
use tower::ServiceBuilder;
use tower::ServiceExt;
use crate::layers::ServiceBuilderExt;
use crate::plugin::Plugin;
use crate::plugin::PluginInit;
use crate::register_plugin;
use crate::services::SupergraphResponse;
use crate::services::supergraph;
// Configuration for the CSRF prevention plugin.
//
// Unknown keys are rejected, and every field falls back to the value supplied by
// the `Default` implementation when it is omitted.
//
// NOTE(review): plain `//` comments are used on purpose — this struct derives
// `JsonSchema`, and `///` doc comments would presumably be copied into the generated
// schema as descriptions. Promote them to `///` if that is desired.
#[derive(Deserialize, Debug, Clone, JsonSchema)]
#[serde(deny_unknown_fields)]
#[serde(default)]
pub(crate) struct CSRFConfig {
// When `true`, `supergraph_service` returns the inner service untouched, so no
// CSRF check is performed at all. Defaults to `false`.
unsafe_disabled: bool,
// Header names accepted as evidence of a preflighted request: a request carrying
// at least one of them passes the check regardless of its content type.
// Defaults to `x-apollo-operation-name` and `apollo-require-preflight`.
required_headers: Arc<Vec<String>>,
}
/// Returns the default list of header names whose presence marks a request as
/// preflighted: `x-apollo-operation-name` and `apollo-require-preflight`, in that order.
fn apollo_custom_preflight_headers() -> Arc<Vec<String>> {
    let default_names = ["x-apollo-operation-name", "apollo-require-preflight"];
    let owned_names: Vec<String> = default_names.iter().copied().map(String::from).collect();
    Arc::new(owned_names)
}
impl Default for CSRFConfig {
fn default() -> Self {
Self {
unsafe_disabled: false,
required_headers: apollo_custom_preflight_headers(),
}
}
}
/// Content types that a browser may send cross-origin without a CORS preflight
/// (the CORS-safelisted `content-type` values of the Fetch specification).
///
/// A request whose content type is one of these is not treated as preflighted by
/// `content_type_requires_preflight`; the list is also echoed in the CSRF error message.
static NON_PREFLIGHTED_CONTENT_TYPES: &[&str] = &[
"application/x-www-form-urlencoded",
"multipart/form-data",
"text/plain",
];
/// CSRF prevention plugin.
///
/// Unless disabled through its configuration, it guards the supergraph service and
/// answers requests that do not look preflighted with a `CSRF_ERROR` response.
#[derive(Debug, Clone)]
pub(crate) struct Csrf {
/// Configuration captured from `PluginInit` when the plugin is created.
config: CSRFConfig,
}
#[async_trait::async_trait]
impl Plugin for Csrf {
type Config = CSRFConfig;
async fn new(init: PluginInit<Self::Config>) -> Result<Self, BoxError> {
Ok(Csrf {
config: init.config,
})
}
fn supergraph_service(&self, service: supergraph::BoxService) -> supergraph::BoxService {
if !self.config.unsafe_disabled {
let required_headers = self.config.required_headers.clone();
ServiceBuilder::new()
.checkpoint(move |req: supergraph::Request| {
if is_preflighted(&req, required_headers.as_slice()) {
tracing::trace!("request is preflighted");
Ok(ControlFlow::Continue(req))
} else {
tracing::trace!("request is not preflighted");
let error = crate::error::Error::builder().message(
format!(
"This operation has been blocked as a potential Cross-Site Request Forgery (CSRF). \
Please either specify a 'content-type' header (with a mime-type that is not one of {}) \
or provide one of the following headers: {}",
NON_PREFLIGHTED_CONTENT_TYPES.join(", "),
required_headers.join(", ")
))
.extension_code("CSRF_ERROR")
.build();
let res = SupergraphResponse::infallible_builder()
.error(error)
.status_code(StatusCode::BAD_REQUEST)
.context(req.context)
.build();
Ok(ControlFlow::Break(res))
}
})
.service(service)
.boxed()
} else {
service
}
}
}
/// Decides whether a request can be assumed to have gone through a CORS preflight.
///
/// The request qualifies when either its `content-type` is one that requires a
/// preflight, or it carries at least one of `required_headers`. The content-type
/// check runs first; the header check only runs when the first one fails.
fn is_preflighted(req: &supergraph::Request, required_headers: &[String]) -> bool {
    let request_headers = req.supergraph_request.headers();
    if content_type_requires_preflight(request_headers) {
        return true;
    }
    recommended_header_is_provided(request_headers, required_headers)
}
/// Reports whether the request's `content-type` is one that forces a CORS preflight.
///
/// Every `content-type` value is trimmed, has its tabs replaced by spaces, and the
/// values are joined with `", "` before being parsed as a single MIME type. The answer
/// is `true` only when that parse succeeds and the MIME essence is not listed in
/// `NON_PREFLIGHTED_CONTENT_TYPES`.
///
/// The function answers `false` (preflight not proven) when a header value cannot be
/// read as a string, or when the combined value does not parse as a MIME type — which
/// is also what happens when no `content-type` header is present.
fn content_type_requires_preflight(headers: &HeaderMap) -> bool {
    let mut normalized_values = Vec::new();
    for raw_value in headers.get_all(header::CONTENT_TYPE).iter() {
        match raw_value.to_str() {
            // Trim and turn tabs into spaces before combining the values —
            // presumably to accommodate the MIME parser; TODO confirm.
            Ok(text) => normalized_values.push(text.trim().replace('\u{0009}', "\u{0020}")),
            // An unreadable value: stay on the cautious side.
            Err(_) => return false,
        }
    }

    // Combine multiple header values into one, separated by ", ".
    let combined_value = normalized_values.join("\u{002C}\u{0020}");

    match combined_value.parse::<mime::Mime>() {
        Ok(mime_type) => {
            let essence = mime_type.essence_str();
            NON_PREFLIGHTED_CONTENT_TYPES
                .iter()
                .all(|safelisted| *safelisted != essence)
        }
        // Not a valid MIME type: do not claim the request was preflighted.
        Err(_) => false,
    }
}
/// Returns `true` when `headers` contains at least one of the names in
/// `required_headers`; the header's value is not inspected. An empty
/// `required_headers` list always yields `false`.
fn recommended_header_is_provided(headers: &HeaderMap, required_headers: &[String]) -> bool {
    required_headers
        .iter()
        .any(|name| headers.contains_key(name))
}
// Registers `Csrf` under the `apollo` group with the name `csrf`; the plugin
// registry exposes it as `apollo.csrf`.
register_plugin!("apollo", "csrf", Csrf);
#[cfg(test)]
mod csrf_tests {
    use http::header::CONTENT_TYPE;
    use mime::APPLICATION_JSON;
    use serde_json_bytes::json;
    use tower::ServiceExt;

    use super::*;
    use crate::plugin::PluginInit;
    use crate::plugin::test::MockSupergraphService;

    /// Strips any `content-type` header from `request`.
    ///
    /// The fake request builder appears to supply a default content type (the
    /// rejection test has always had to remove it), so a test that wants a request
    /// which is NOT preflighted by its content type must remove the header first.
    fn without_content_type(mut request: supergraph::Request) -> supergraph::Request {
        request
            .supergraph_request
            .headers_mut()
            .remove(CONTENT_TYPE);
        request
    }

    /// The plugin is discoverable as `apollo.csrf` and can be instantiated both
    /// with an explicit configuration and with an empty one.
    #[tokio::test]
    async fn plugin_registered() {
        crate::plugin::plugins()
            .find(|factory| factory.name == "apollo.csrf")
            .expect("Plugin not found")
            .create_instance_without_schema(&serde_json::json!({ "unsafe_disabled": true }))
            .await
            .unwrap();

        crate::plugin::plugins()
            .find(|factory| factory.name == "apollo.csrf")
            .expect("Plugin not found")
            .create_instance_without_schema(&serde_json::json!({}))
            .await
            .unwrap();
    }

    /// A request passes when either its content type requires a preflight or it
    /// carries one of the required headers.
    #[tokio::test]
    async fn it_lets_preflighted_request_pass_through() {
        let config = CSRFConfig::default();

        let with_preflight_content_type = supergraph::Request::fake_builder()
            .header(CONTENT_TYPE, APPLICATION_JSON.essence_str())
            .build()
            .unwrap();
        assert_accepted(config.clone(), with_preflight_content_type).await;

        // Remove the content type so that only the required header can let the
        // request through; otherwise the header path would not be exercised.
        let with_preflight_header = without_content_type(
            supergraph::Request::fake_builder()
                .header("apollo-require-preflight", "this-is-a-test")
                .build()
                .unwrap(),
        );
        assert_accepted(config, with_preflight_header).await;
    }

    /// A request with neither a content type nor a required header is rejected.
    #[tokio::test]
    async fn it_rejects_non_preflighted_headers_request() {
        let config = CSRFConfig::default();
        let non_preflighted_request =
            without_content_type(supergraph::Request::fake_builder().build().unwrap());
        assert_rejected(config, non_preflighted_request).await
    }

    /// A content type from the non-preflighted list is rejected, with or without
    /// MIME parameters.
    #[tokio::test]
    async fn it_rejects_non_preflighted_content_type_request() {
        let config = CSRFConfig::default();

        let non_preflighted_request = supergraph::Request::fake_builder()
            .header(CONTENT_TYPE, "text/plain")
            .build()
            .unwrap();
        assert_rejected(config.clone(), non_preflighted_request).await;

        let non_preflighted_request = supergraph::Request::fake_builder()
            .header(CONTENT_TYPE, "text/plain; charset=utf8")
            .build()
            .unwrap();
        assert_rejected(config, non_preflighted_request).await;
    }

    /// With `unsafe_disabled` set, even a request that would otherwise be rejected
    /// reaches the wrapped service.
    #[tokio::test]
    async fn it_accepts_non_preflighted_headers_request_when_plugin_is_disabled() {
        let config = CSRFConfig {
            unsafe_disabled: true,
            ..Default::default()
        };
        // Strip the content type so the request really is non-preflighted; without
        // this the test would pass even if the disabled flag were ignored.
        let non_preflighted_request =
            without_content_type(supergraph::Request::fake_builder().build().unwrap());
        assert_accepted(config, non_preflighted_request).await
    }

    /// Runs `request` through the plugin and asserts that the wrapped service is
    /// called exactly once and that its response comes back without errors.
    async fn assert_accepted(config: CSRFConfig, request: supergraph::Request) {
        let mut mock_service = MockSupergraphService::new();
        mock_service.expect_call().times(1).returning(move |_| {
            Ok(SupergraphResponse::fake_builder()
                .data(json!({ "test": 1234_u32 }))
                .build()
                .unwrap())
        });

        let service_stack = Csrf::new(PluginInit::fake_new(config, Default::default()))
            .await
            .unwrap()
            .supergraph_service(mock_service.boxed());

        let res = service_stack
            .oneshot(request)
            .await
            .unwrap()
            .next_response()
            .await
            .unwrap();

        assert_eq!(res.errors, []);
        assert_eq!(res.data.unwrap(), json!({ "test": 1234_u32 }));
    }

    /// Runs `request` through the plugin and asserts that it is answered with the
    /// single CSRF error built for the default configuration. The mock service has
    /// no expectations, so the request is not expected to reach it.
    async fn assert_rejected(config: CSRFConfig, request: supergraph::Request) {
        let service_stack = Csrf::new(PluginInit::fake_new(config, Default::default()))
            .await
            .unwrap()
            .supergraph_service(MockSupergraphService::new().boxed());

        let res = service_stack
            .oneshot(request)
            .await
            .unwrap()
            .next_response()
            .await
            .unwrap();

        assert_eq!(
            1,
            res.errors.len(),
            "expected one(1) error in the SupergraphResponse, found {}\n{:?}",
            res.errors.len(),
            res.errors
        );
        assert_eq!(
            res.errors[0].message,
            "This operation has been blocked as a potential Cross-Site Request Forgery (CSRF). \
            Please either specify a 'content-type' header \
            (with a mime-type that is not one of application/x-www-form-urlencoded, multipart/form-data, text/plain) \
            or provide one of the following headers: x-apollo-operation-name, apollo-require-preflight"
        );
    }
}