use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
use http::HeaderName;
use http::HeaderValue;
use http::Request;
use http::Response;
use http::header::ACCESS_CONTROL_ALLOW_CREDENTIALS;
use http::header::ACCESS_CONTROL_ALLOW_HEADERS;
use http::header::ACCESS_CONTROL_ALLOW_METHODS;
use http::header::ACCESS_CONTROL_ALLOW_ORIGIN;
use http::header::ACCESS_CONTROL_EXPOSE_HEADERS;
use http::header::ACCESS_CONTROL_MAX_AGE;
use http::header::ACCESS_CONTROL_REQUEST_HEADERS;
use http::header::ORIGIN;
use http::header::VARY;
use tower::Layer;
use tower::Service;
use crate::configuration::cors::Cors;
use crate::configuration::cors::Policy;
const ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK: HeaderName =
HeaderName::from_static("access-control-allow-private-network");
const ACCESS_CONTROL_PRIVATE_NETWORK_VALUE: http::HeaderValue = HeaderValue::from_static("true");
const PRIVATE_NETWORK_ACCESS_NAME: HeaderName =
HeaderName::from_static("private-network-access-name");
const PRIVATE_NETWORK_ACCESS_ID: HeaderName = HeaderName::from_static("private-network-access-id");
/// Tower layer that wraps services with CORS handling driven by a [`Cors`] configuration.
///
/// Build it with `CorsLayer::new`, which rejects configurations containing values that
/// cannot be turned into valid HTTP headers.
#[derive(Debug, Clone)]
pub(crate) struct CorsLayer {
    /// Settings cloned into every `CorsService` this layer produces.
    config: Cors,
}
impl CorsLayer {
    /// Validates `config` and builds a layer from it.
    ///
    /// Every configured value that is later written into a response header (header
    /// names, methods, origins, private-network-access name/id) is parsed here. A bad
    /// configuration is therefore rejected at startup instead of producing empty or
    /// missing headers at request time.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message for the first group of invalid settings found.
    /// Within a single list, all invalid entries are reported together, joined by `", "`.
    pub(crate) fn new(config: Cors) -> Result<Self, String> {
        config
            .ensure_usable_cors_rules()
            .map_err(|e| e.to_string())?;

        // Global settings. `parse_values` accepts empty lists, so no emptiness guard is needed.
        parse_values::<http::HeaderName>(&config.allow_headers, "allow header name")?;
        parse_values::<http::Method>(&config.methods, "method")?;
        if let Some(headers) = &config.expose_headers {
            parse_values::<http::HeaderName>(headers, "expose header name")?;
        }

        // Per-policy settings.
        if let Some(policies) = &config.policies {
            for policy in policies.iter() {
                for origin in policy.origins.iter() {
                    http::HeaderValue::from_str(origin).map_err(|_| {
                        format!("origin '{origin}' is not valid: failed to parse header value")
                    })?;
                }
                parse_values::<http::HeaderName>(&policy.allow_headers, "allow header name")?;
                if let Some(methods) = &policy.methods {
                    parse_values::<http::Method>(methods, "method")?;
                }
                parse_values::<http::HeaderName>(&policy.expose_headers, "expose header name")?;

                // The PNA name/id are echoed verbatim as header values on preflights; validate
                // them here so an invalid value fails fast rather than being sent as "".
                if let Some(pna) = &policy.private_network_access {
                    if let Some(name) = &pna.access_name {
                        http::HeaderValue::from_str(name).map_err(|_| {
                            format!(
                                "private network access name '{name}' is not valid: failed to parse header value"
                            )
                        })?;
                    }
                    if let Some(id) = &pna.access_id {
                        http::HeaderValue::from_str(id).map_err(|_| {
                            format!(
                                "private network access id '{id}' is not valid: failed to parse header value"
                            )
                        })?;
                    }
                }
            }
        }
        Ok(Self { config })
    }
}
impl<S> Layer<S> for CorsLayer {
    type Service = CorsService<S>;

    /// Wraps `service` in a `CorsService` that owns its own copy of this layer's config.
    fn layer(&self, service: S) -> Self::Service {
        let config = self.config.clone();
        CorsService {
            config,
            inner: service,
        }
    }
}
/// Service produced by `CorsLayer`: answers preflights itself and decorates the
/// responses of the wrapped service with CORS headers.
#[derive(Debug, Clone)]
pub(crate) struct CorsService<S> {
    /// The wrapped service that handles non-preflight requests.
    inner: S,
    /// CORS settings used to compute response headers.
    config: Cors,
}
impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for CorsService<S>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Send + 'static,
    ReqBody: Send + 'static,
    ResBody: Send + 'static + Default,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Readiness is delegated unchanged to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Handles one request.
    ///
    /// Any `OPTIONS` request is treated as a CORS preflight. It is answered directly with
    /// an empty `200 OK` carrying the CORS headers, and the inner service is not called.
    /// Every other request is forwarded to the inner service, and the CORS headers are
    /// added to its response.
    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
        // Copy out everything needed from the request before `req` may be moved.
        let config = self.config.clone();
        let origin = req.headers().get(ORIGIN).cloned();
        let requested_headers = req.headers().get(ACCESS_CONTROL_REQUEST_HEADERS).cloned();

        if req.method() != http::Method::OPTIONS {
            let pending = self.inner.call(req);
            return Box::pin(async move {
                let mut response = pending.await?;
                let policy = Self::find_matching_policy(&config, &origin);
                Self::add_cors_headers(
                    &mut response,
                    &config,
                    &policy,
                    &origin,
                    false,
                    requested_headers,
                );
                Ok(response)
            });
        }

        let mut response = Response::builder()
            .status(http::StatusCode::OK)
            .body(ResBody::default())
            .unwrap();
        let policy = Self::find_matching_policy(&config, &origin);
        Self::add_cors_headers(
            &mut response,
            &config,
            &policy,
            &origin,
            true,
            requested_headers,
        );
        Box::pin(async move { Ok(response) })
    }
}
impl<S> CorsService<S> {
    /// Returns the first configured policy that accepts the request `origin`, if any.
    ///
    /// Returns `None` in three cases: there is no `Origin` header, its value is not
    /// visible ASCII, or no policy matches. Policies are tried in configuration order.
    /// Within each policy, the exact `origins` entries are compared before the
    /// `match_origins` regexes. The literal origin `null` is rejected outright unless
    /// `allow_any_origin` is enabled, so a permissive regex (e.g. `.*`) cannot admit it.
    fn find_matching_policy<'a>(
        config: &'a Cors,
        origin: &'a Option<http::HeaderValue>,
    ) -> Option<&'a Policy> {
        let origin_str = origin.as_ref()?.to_str().ok()?;
        if origin_str == "null" && !config.allow_any_origin {
            return None;
        }
        if let Some(policies) = &config.policies {
            for policy in policies.iter() {
                for url in policy.origins.iter() {
                    if &**url == origin_str {
                        return Some(policy);
                    }
                }
                if !policy.match_origins.is_empty() {
                    for regex in policy.match_origins.iter() {
                        if regex.is_match(origin_str) {
                            return Some(policy);
                        }
                    }
                }
            }
        }
        None
    }

    /// Writes the CORS response headers for one request into `response`.
    ///
    /// Settings from the matched `policy` take precedence over the global `config`.
    /// An empty policy allow/expose header list falls back to the global list. A policy
    /// `methods` list, when present, is used as-is even if empty.
    ///
    /// * `Access-Control-Allow-Origin` is set only when the request carried an `Origin`.
    ///   Its value is `*` when any origin is allowed; otherwise the request origin is
    ///   echoed back, but only when a policy matched.
    /// * `Access-Control-Allow-Credentials: true` is set whenever credentials are enabled.
    /// * Preflights also get `Access-Control-Allow-Headers`. This is the configured list,
    ///   or the request's `Access-Control-Request-Headers` mirrored back when none are
    ///   configured. They also get allow-methods, max-age (when configured) and, when the
    ///   matched policy enables it, the private-network-access headers.
    /// * Non-preflight responses get `Access-Control-Expose-Headers` when configured.
    /// * `origin` is always added to `Vary`, plus `access-control-request-headers` for
    ///   preflights.
    fn add_cors_headers<ResBody>(
        response: &mut Response<ResBody>,
        config: &Cors,
        policy: &Option<&Policy>,
        request_origin: &Option<http::HeaderValue>,
        is_preflight: bool,
        request_headers: Option<http::HeaderValue>,
    ) {
        let allow_credentials = policy
            .and_then(|p| p.allow_credentials)
            .unwrap_or(config.allow_credentials);
        let allow_headers = policy
            .and_then(|p| {
                if p.allow_headers.is_empty() {
                    None
                } else {
                    Some(&p.allow_headers)
                }
            })
            .unwrap_or(&config.allow_headers);
        let expose_headers = if let Some(policy) = policy {
            if policy.expose_headers.is_empty() {
                config.expose_headers.as_ref()
            } else {
                Some(&policy.expose_headers)
            }
        } else {
            config.expose_headers.as_ref()
        };
        let methods = if let Some(policy) = policy {
            match &policy.methods {
                None => &config.methods,
                Some(methods) => methods,
            }
        } else {
            &config.methods
        };
        let max_age = policy.and_then(|p| p.max_age).or(config.max_age);
        if let Some(origin) = request_origin {
            if config.allow_any_origin {
                response.headers_mut().insert(
                    ACCESS_CONTROL_ALLOW_ORIGIN,
                    http::HeaderValue::from_static("*"),
                );
            } else if policy.is_some() {
                response
                    .headers_mut()
                    .insert(ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone());
            }
        }
        if allow_credentials {
            response.headers_mut().insert(
                ACCESS_CONTROL_ALLOW_CREDENTIALS,
                http::HeaderValue::from_static("true"),
            );
        }
        if is_preflight {
            if !allow_headers.is_empty() {
                let header_value = allow_headers.join(", ");
                response.headers_mut().insert(
                    ACCESS_CONTROL_ALLOW_HEADERS,
                    http::HeaderValue::from_str(&header_value)
                        .unwrap_or_else(|_| http::HeaderValue::from_static("")),
                );
            } else {
                // No allow-list configured: mirror whatever the browser asked for.
                if let Some(request_headers) = request_headers
                    && let Ok(headers_str) = request_headers.to_str()
                {
                    response.headers_mut().insert(
                        ACCESS_CONTROL_ALLOW_HEADERS,
                        http::HeaderValue::from_str(headers_str)
                            .unwrap_or_else(|_| http::HeaderValue::from_static("")),
                    );
                }
            }
        }
        if !is_preflight && let Some(headers) = expose_headers {
            let header_value = headers.join(", ");
            response.headers_mut().insert(
                ACCESS_CONTROL_EXPOSE_HEADERS,
                http::HeaderValue::from_str(&header_value)
                    .unwrap_or_else(|_| http::HeaderValue::from_static("")),
            );
        }
        if is_preflight {
            let method_value = methods.join(", ");
            response.headers_mut().insert(
                ACCESS_CONTROL_ALLOW_METHODS,
                http::HeaderValue::from_str(&method_value)
                    .unwrap_or_else(|_| http::HeaderValue::from_static("")),
            );
        }
        if is_preflight && let Some(max_age) = max_age {
            let max_age_secs = max_age.as_secs();
            response.headers_mut().insert(
                ACCESS_CONTROL_MAX_AGE,
                http::HeaderValue::from_str(&max_age_secs.to_string())
                    .unwrap_or_else(|_| http::HeaderValue::from_static("")),
            );
        }
        if is_preflight
            && let Some(Some(pna)) = policy.map(|policy| policy.private_network_access.as_ref())
        {
            response.headers_mut().insert(
                ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK,
                ACCESS_CONTROL_PRIVATE_NETWORK_VALUE,
            );
            if let Some(name) = &pna.access_name {
                response.headers_mut().insert(
                    PRIVATE_NETWORK_ACCESS_NAME,
                    http::HeaderValue::from_str(name)
                        .unwrap_or_else(|_| http::HeaderValue::from_static("")),
                );
            }
            if let Some(id) = &pna.access_id {
                response.headers_mut().insert(
                    PRIVATE_NETWORK_ACCESS_ID,
                    http::HeaderValue::from_str(id)
                        .unwrap_or_else(|_| http::HeaderValue::from_static("")),
                );
            }
        }
        // The allowed origin (and, for preflights, the mirrored headers) depend on these
        // request headers, so caches must key on them.
        Self::append_vary_header(response, ORIGIN);
        if is_preflight {
            Self::append_vary_header(response, ACCESS_CONTROL_REQUEST_HEADERS);
        }
    }

    /// Adds `value` to the response's `Vary` header unless it is already listed.
    ///
    /// HTTP allows `Vary` to appear on several field lines, so every existing line is
    /// inspected and the comparison is ASCII case-insensitive. When `value` must be added,
    /// all existing non-blank lines are folded into a single comma-separated value followed
    /// by `value`. That way no previously listed field name is lost when the header is
    /// replaced. If any existing line is not visible ASCII, the header is left untouched
    /// and an error is logged.
    fn append_vary_header<ResBody>(response: &mut Response<ResBody>, value: http::HeaderName) {
        let headers = response.headers_mut();

        let mut existing_lines: Vec<&str> = Vec::new();
        for existing_vary in headers.get_all(VARY) {
            let Ok(existing_str) = existing_vary.to_str() else {
                let lossy_str = String::from_utf8_lossy(existing_vary.as_bytes());
                tracing::error!(
                    "could not append Vary header, because the existing value is not UTF-8: {lossy_str}"
                );
                return;
            };
            if !existing_str.trim().is_empty() {
                existing_lines.push(existing_str);
            }
        }

        let already_listed = existing_lines
            .iter()
            .flat_map(|line| line.split(','))
            .any(|existing| existing.trim().eq_ignore_ascii_case(value.as_str()));
        if already_listed {
            return;
        }

        let new_header_value = if existing_lines.is_empty() {
            http::HeaderValue::from(value)
        } else {
            let new_vary = format!("{}, {value}", existing_lines.join(", "));
            http::HeaderValue::from_str(&new_vary).expect(
                "combining pre-existing header + hardcoded valid value can not produce an invalid result",
            )
        };
        // `insert` replaces every existing `Vary` line; they are all folded into the new value.
        headers.insert(VARY, new_header_value);
    }
}
/// Parses every string in `values_to_parse` as a `T`.
///
/// All entries are attempted, even after a failure, so that every invalid entry is
/// reported at once. On success, the parsed values are returned in input order.
///
/// # Errors
///
/// Returns one `"{error_description} '{value}' is not valid: {reason}"` message per
/// invalid entry, joined with `", "`.
fn parse_values<T>(values_to_parse: &[Arc<str>], error_description: &str) -> Result<Vec<T>, String>
where
    T: std::str::FromStr,
    <T as std::str::FromStr>::Err: std::fmt::Display,
{
    let (parsed, failures) = values_to_parse.iter().fold(
        (Vec::with_capacity(values_to_parse.len()), Vec::new()),
        |(mut parsed, mut failures), raw| {
            match raw.parse::<T>() {
                Ok(value) => parsed.push(value),
                Err(err) => {
                    failures.push(format!("{error_description} '{raw}' is not valid: {err}"))
                }
            }
            (parsed, failures)
        },
    );
    if !failures.is_empty() {
        return Err(failures.join(", "));
    }
    Ok(parsed)
}
#[cfg(test)]
mod tests {
    use std::task::Context;
    use std::task::Poll;

    use http::Request;
    use http::Response;
    use http::StatusCode;
    use http::header::ACCESS_CONTROL_ALLOW_ORIGIN;
    use http::header::ACCESS_CONTROL_EXPOSE_HEADERS;
    use http::header::ORIGIN;
    use tower::Service;

    use super::*;
    use crate::configuration::cors::Cors;
    use crate::configuration::cors::Policy;
    use crate::configuration::cors::PrivateNetworkAccessPolicy;

    /// Origin used by tests that do not care which origin is sent.
    const STUDIO_ORIGIN: &str = "https://studio.apollographql.com";

    /// Inner service that always answers `200 ok`, optionally with a fixed `Vary` header.
    struct StubService {
        vary: Option<&'static str>,
    }

    impl StubService {
        /// A stub whose responses carry no `Vary` header.
        fn plain() -> Self {
            Self { vary: None }
        }

        /// A stub whose responses carry `Vary: <vary>`.
        fn with_vary(vary: &'static str) -> Self {
            Self { vary: Some(vary) }
        }
    }

    impl Service<Request<()>> for StubService {
        type Response = Response<&'static str>;
        type Error = ();
        type Future = std::future::Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: Request<()>) -> Self::Future {
            let mut builder = Response::builder().status(StatusCode::OK);
            if let Some(vary) = self.vary {
                builder = builder.header(VARY, vary);
            }
            std::future::ready(Ok(builder.body("ok").unwrap()))
        }
    }

    /// Wraps `inner` in a CORS layer built from `cors` (panics on an invalid config).
    fn cors_service(cors: Cors, inner: StubService) -> CorsService<StubService> {
        CorsLayer::new(cors).unwrap().layer(inner)
    }

    /// Drives one request through `service` to completion.
    fn send(service: &mut CorsService<StubService>, req: Request<()>) -> Response<&'static str> {
        futures::executor::block_on(service.call(req)).unwrap()
    }

    /// Sends `req` through a fresh CORS service built from `cors` around `inner`.
    fn respond_via(cors: Cors, inner: StubService, req: Request<()>) -> Response<&'static str> {
        send(&mut cors_service(cors, inner), req)
    }

    /// Sends `req` through a fresh CORS service built from `cors` around a plain stub.
    fn respond(cors: Cors, req: Request<()>) -> Response<&'static str> {
        respond_via(cors, StubService::plain(), req)
    }

    /// `GET /` carrying the given `Origin`.
    fn get_from(origin: &'static str) -> http::request::Builder {
        Request::get("/").header(ORIGIN, origin)
    }

    /// `OPTIONS /` (a preflight) carrying the given `Origin`.
    fn preflight_from(origin: &'static str) -> http::request::Builder {
        Request::options("/").header(ORIGIN, origin)
    }

    /// Config whose single policy admits origins only through the regex `pattern`.
    fn regex_only_cors(pattern: &str) -> Cors {
        Cors::builder()
            .allow_any_origin(false)
            .policies(vec![
                Policy::builder()
                    .origins(vec![])
                    .match_origins(vec![regex::Regex::new(pattern).unwrap()])
                    .allow_credentials(false)
                    .build(),
            ])
            .build()
    }

    /// The response's `Vary` header as a string (panics if absent).
    fn vary_of(resp: &Response<&'static str>) -> &str {
        resp.headers().get(VARY).unwrap().to_str().unwrap()
    }

    #[test]
    fn test_bad_allow_headers_cors_configuration() {
        let cors = Cors::builder()
            .allow_headers(vec![String::from("bad\nname")])
            .build();
        assert_eq!(
            CorsLayer::new(cors).unwrap_err(),
            "allow header name 'bad\nname' is not valid: invalid HTTP header name"
        );
    }

    #[test]
    fn test_bad_allow_methods_cors_configuration() {
        let cors = Cors::builder()
            .methods(vec![String::from("bad\nmethod")])
            .build();
        assert_eq!(
            CorsLayer::new(cors).unwrap_err(),
            "method 'bad\nmethod' is not valid: invalid HTTP method"
        );
    }

    #[test]
    fn test_bad_origins_cors_configuration() {
        let cors = Cors::builder()
            .policies(vec![
                Policy::builder()
                    .origins(vec![String::from("bad\norigin")])
                    .build(),
            ])
            .build();
        assert_eq!(
            CorsLayer::new(cors).unwrap_err(),
            "origin 'bad\norigin' is not valid: failed to parse header value"
        );
    }

    #[test]
    fn test_good_cors_configuration() {
        let cors = Cors::builder()
            .allow_headers(vec![String::from("good-name")])
            .build();
        assert!(CorsLayer::new(cors).is_ok());
    }

    #[test]
    fn test_non_preflight_cors_headers() {
        let cors = Cors::builder()
            .policies(vec![
                Policy::builder()
                    .origins(vec!["https://trusted.com".into()])
                    .expose_headers(vec!["x-custom-header".into()])
                    .build(),
            ])
            .build();
        let resp = respond(cors, get_from("https://trusted.com").body(()).unwrap());
        let headers = resp.headers();
        assert_eq!(
            headers.get(ACCESS_CONTROL_ALLOW_ORIGIN).unwrap(),
            "https://trusted.com"
        );
        assert_eq!(
            headers.get(ACCESS_CONTROL_EXPOSE_HEADERS).unwrap(),
            "x-custom-header"
        );
    }

    #[test]
    fn test_expose_headers_non_preflight_set() {
        let cors = Cors::builder()
            .expose_headers(vec!["x-foo".into(), "x-bar".into()])
            .build();
        let resp = respond(cors, get_from(STUDIO_ORIGIN).body(()).unwrap());
        assert_eq!(
            resp.headers().get(ACCESS_CONTROL_EXPOSE_HEADERS).unwrap(),
            "x-foo, x-bar"
        );
    }

    #[test]
    fn test_expose_headers_non_preflight_not_set() {
        let resp = respond(
            Cors::builder().build(),
            get_from(STUDIO_ORIGIN).body(()).unwrap(),
        );
        assert!(resp.headers().get(ACCESS_CONTROL_EXPOSE_HEADERS).is_none());
    }

    #[test]
    fn test_mirror_request_headers_preflight() {
        let cors = Cors::builder().allow_headers(vec![]).build();
        let req = preflight_from(STUDIO_ORIGIN)
            .header(ACCESS_CONTROL_REQUEST_HEADERS, "x-foo, x-bar")
            .body(())
            .unwrap();
        let resp = respond(cors, req);
        let allow_headers = resp.headers().get(ACCESS_CONTROL_ALLOW_HEADERS).unwrap();
        assert_eq!(allow_headers, "x-foo, x-bar");
    }

    #[test]
    fn test_no_mirror_request_headers_non_preflight() {
        let cors = Cors::builder().allow_headers(vec![]).build();
        let req = get_from(STUDIO_ORIGIN)
            .header(ACCESS_CONTROL_REQUEST_HEADERS, "x-foo, x-bar")
            .body(())
            .unwrap();
        let resp = respond(cors, req);
        assert!(resp.headers().get(ACCESS_CONTROL_ALLOW_HEADERS).is_none());
    }

    #[test]
    fn test_cors_headers_comma_separated_format() {
        let cors = Cors::builder()
            .allow_headers(vec![
                "content-type".into(),
                "authorization".into(),
                "x-custom".into(),
            ])
            .build();
        let resp = respond(cors, preflight_from(STUDIO_ORIGIN).body(()).unwrap());
        let headers = resp.headers();
        let allow_headers = headers.get(ACCESS_CONTROL_ALLOW_HEADERS).unwrap();
        assert_eq!(allow_headers, "content-type, authorization, x-custom");
        assert_eq!(headers.get_all(ACCESS_CONTROL_ALLOW_HEADERS).iter().count(), 1);
    }

    #[test]
    fn test_cors_methods_comma_separated_format() {
        let cors = Cors {
            allow_any_origin: false,
            allow_credentials: false,
            allow_headers: Arc::new([]),
            expose_headers: None,
            methods: Arc::new(["GET".into(), "POST".into(), "PUT".into()]),
            max_age: None,
            policies: None,
        };
        let resp = respond(cors, preflight_from(STUDIO_ORIGIN).body(()).unwrap());
        let headers = resp.headers();
        let allow_methods = headers.get(ACCESS_CONTROL_ALLOW_METHODS).unwrap();
        assert_eq!(allow_methods, "GET, POST, PUT");
        assert_eq!(headers.get_all(ACCESS_CONTROL_ALLOW_METHODS).iter().count(), 1);
    }

    #[test]
    fn test_policy_methods_fallback_to_global() {
        let cors = Cors::builder()
            .methods(vec!["POST".into()])
            .policies(vec![
                Policy::builder()
                    .origins(vec!["https://example.com".into()])
                    .build(),
            ])
            .build();
        let resp = respond(cors, preflight_from("https://example.com").body(()).unwrap());
        let allow_methods = resp.headers().get(ACCESS_CONTROL_ALLOW_METHODS).unwrap();
        assert_eq!(allow_methods, "POST");
    }

    #[test]
    fn test_policy_empty_methods_runtime() {
        let cors = Cors::builder()
            .methods(vec!["POST".into(), "PUT".into()])
            .policies(vec![
                Policy::builder()
                    .origins(vec!["https://example.com".into()])
                    .methods(vec![])
                    .build(),
            ])
            .build();
        let resp = respond(cors, preflight_from("https://example.com").body(()).unwrap());
        let allow_methods = resp.headers().get(ACCESS_CONTROL_ALLOW_METHODS).unwrap();
        assert_eq!(allow_methods, "");
    }

    #[test]
    fn test_policy_specific_methods_runtime() {
        let cors = Cors::builder()
            .methods(vec!["POST".into()])
            .policies(vec![
                Policy::builder()
                    .origins(vec!["https://example.com".into()])
                    .methods(vec!["GET".into(), "DELETE".into()])
                    .build(),
            ])
            .build();
        let resp = respond(cors, preflight_from("https://example.com").body(()).unwrap());
        let allow_methods = resp.headers().get(ACCESS_CONTROL_ALLOW_METHODS).unwrap();
        assert_eq!(allow_methods, "GET, DELETE");
    }

    #[test]
    fn test_null_origin_rejected_with_catch_all_regex() {
        let resp = respond(regex_only_cors(".*"), get_from("null").body(()).unwrap());
        assert!(resp.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).is_none());
    }

    #[test]
    fn test_null_origin_rejected_with_specific_regex() {
        let resp = respond(regex_only_cors("n.ll"), get_from("null").body(()).unwrap());
        assert!(resp.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).is_none());
    }

    #[test]
    fn test_null_origin_allowed_with_allow_any_origin() {
        let cors = Cors::builder().allow_any_origin(true).build();
        let resp = respond(cors, get_from("null").body(()).unwrap());
        assert_eq!(resp.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).unwrap(), "*");
    }

    #[test]
    fn test_regular_origins_still_work_with_null_guard() {
        let cors = Cors::builder()
            .allow_any_origin(false)
            .policies(vec![
                Policy::builder()
                    .origins(vec!["https://example.com".into()])
                    .build(),
                Policy::builder()
                    .origins(vec![])
                    .match_origins(vec![regex::Regex::new("https://.*\\.test\\.com").unwrap()])
                    .build(),
            ])
            .build();
        // One service instance handles all three requests.
        let mut service = cors_service(cors, StubService::plain());
        for (origin, expected) in [
            ("https://example.com", Some("https://example.com")),
            ("https://api.test.com", Some("https://api.test.com")),
            ("https://malicious.com", None),
        ] {
            let resp = send(&mut service, get_from(origin).body(()).unwrap());
            let allow_origin = resp
                .headers()
                .get(ACCESS_CONTROL_ALLOW_ORIGIN)
                .map(|value| value.to_str().unwrap());
            assert_eq!(allow_origin, expected);
        }
    }

    #[test]
    fn test_null_origin_preflight_request_rejected() {
        let resp = respond(regex_only_cors(".*"), preflight_from("null").body(()).unwrap());
        assert!(resp.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).is_none());
    }

    #[test]
    fn test_null_origin_preflight_allowed_with_allow_any_origin() {
        let cors = Cors::builder().allow_any_origin(true).build();
        let resp = respond(cors, preflight_from("null").body(()).unwrap());
        assert_eq!(resp.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).unwrap(), "*");
    }

    #[test]
    fn test_allow_any_origin() {
        let cors = Cors::builder().allow_any_origin(true).build();
        let resp = respond(cors, get_from("http://example.com/").body(()).unwrap());
        assert_eq!(resp.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).unwrap(), "*");
    }

    #[test]
    fn test_allow_any_origin_nocors() {
        let cors = Cors::builder().allow_any_origin(true).build();
        let resp = respond(cors, Request::get("/").body(()).unwrap());
        assert!(resp.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).is_none());
    }

    #[test]
    fn test_allow_any_origin_preflight() {
        let cors = Cors::builder().allow_any_origin(true).build();
        let resp = respond(cors, preflight_from("http://example.com/").body(()).unwrap());
        assert_eq!(resp.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).unwrap(), "*");
    }

    #[test]
    fn test_allow_any_origin_nocors_preflight() {
        let cors = Cors::builder().allow_any_origin(true).build();
        let resp = respond(cors, Request::options("/").body(()).unwrap());
        assert!(resp.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).is_none());
    }

    #[test]
    fn test_vary_header_set_for_cors_requests() {
        let resp = respond(
            Cors::builder().build(),
            get_from(STUDIO_ORIGIN).body(()).unwrap(),
        );
        assert_eq!(vary_of(&resp), "origin");
    }

    #[test]
    fn test_vary_header_preserves_existing_values() {
        let resp = respond_via(
            Cors::builder().build(),
            StubService::with_vary("Accept-Encoding, User-Agent"),
            get_from(STUDIO_ORIGIN).body(()).unwrap(),
        );
        assert_eq!(vary_of(&resp), "Accept-Encoding, User-Agent, origin");
    }

    #[test]
    fn test_vary_header_no_duplicates() {
        let resp = respond_via(
            Cors::builder().build(),
            StubService::with_vary("accept-encoding, origin, user-agent"),
            get_from(STUDIO_ORIGIN).body(()).unwrap(),
        );
        assert_eq!(vary_of(&resp), "accept-encoding, origin, user-agent");
    }

    #[test]
    fn test_vary_header_no_duplicates_case_insensitive() {
        let resp = respond_via(
            Cors::builder().build(),
            StubService::with_vary("Accept-Encoding, Origin, User-Agent"),
            get_from(STUDIO_ORIGIN).body(()).unwrap(),
        );
        assert_eq!(vary_of(&resp), "Accept-Encoding, Origin, User-Agent");
    }

    #[test]
    fn test_vary_header_preflight_includes_request_headers() {
        let req = preflight_from(STUDIO_ORIGIN)
            .header(ACCESS_CONTROL_REQUEST_HEADERS, "content-type")
            .body(())
            .unwrap();
        let resp = respond(Cors::builder().build(), req);
        let vary_header = vary_of(&resp);
        assert!(vary_header.contains("origin"));
        assert!(vary_header.contains("access-control-request-headers"));
    }

    #[test]
    fn test_vary_header_non_preflight_only_origin() {
        let req = get_from(STUDIO_ORIGIN)
            .header(ACCESS_CONTROL_REQUEST_HEADERS, "content-type")
            .body(())
            .unwrap();
        let resp = respond(Cors::builder().build(), req);
        assert_eq!(vary_of(&resp), "origin");
    }

    #[test]
    fn test_vary_header_preserves_complex_existing_values_non_preflight() {
        let resp = respond_via(
            Cors::builder().build(),
            StubService::with_vary("Accept-Language, Accept-Encoding"),
            get_from(STUDIO_ORIGIN).body(()).unwrap(),
        );
        assert_eq!(vary_of(&resp), "Accept-Language, Accept-Encoding, origin");
    }

    #[test]
    fn test_vary_header_preflight_only_cors_headers() {
        let resp = respond(
            Cors::builder().build(),
            preflight_from(STUDIO_ORIGIN).body(()).unwrap(),
        );
        assert_eq!(vary_of(&resp), "origin, access-control-request-headers");
    }

    #[test]
    fn pna_without_other_headers() {
        let cors = Cors::builder()
            .policies(vec![
                Policy::builder()
                    .private_network_access(PrivateNetworkAccessPolicy::builder().build())
                    .origin("https://studio.apollographql.com")
                    .build(),
            ])
            .build();
        println!("{cors:?}");
        let resp = respond(cors, preflight_from(STUDIO_ORIGIN).body(()).unwrap());
        let headers = resp.headers();
        assert!(headers.get(ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK).is_some());
        assert!(headers.get(PRIVATE_NETWORK_ACCESS_NAME).is_none());
        assert!(headers.get(PRIVATE_NETWORK_ACCESS_ID).is_none());
    }

    #[test]
    fn pna_with_access_name() {
        let cors = Cors::builder()
            .policies(vec![
                Policy::builder()
                    .private_network_access(
                        PrivateNetworkAccessPolicy::builder()
                            .access_name("ferris-server")
                            .build(),
                    )
                    .origin("https://studio.apollographql.com")
                    .build(),
            ])
            .allow_any_origin(true)
            .build();
        let resp = respond(cors, preflight_from(STUDIO_ORIGIN).body(()).unwrap());
        let headers = resp.headers();
        assert!(headers.get(ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK).is_some());
        assert!(
            headers
                .get(PRIVATE_NETWORK_ACCESS_NAME)
                .is_some_and(|name| name == "ferris-server")
        );
        assert!(headers.get(PRIVATE_NETWORK_ACCESS_ID).is_none());
    }

    #[test]
    fn pna_with_access_name_and_id() {
        let cors = Cors::builder()
            .policies(vec![
                Policy::builder()
                    .private_network_access(
                        PrivateNetworkAccessPolicy::builder()
                            .access_name("ferris-server")
                            .access_id("01:23:45:67:89:0A")
                            .build(),
                    )
                    .origin("https://studio.apollographql.com")
                    .build(),
            ])
            .build();
        let resp = respond(cors, preflight_from(STUDIO_ORIGIN).body(()).unwrap());
        for header in resp.headers() {
            println!("{header:?}");
        }
        let headers = resp.headers();
        assert!(headers.get(ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK).is_some());
        assert!(
            headers
                .get(PRIVATE_NETWORK_ACCESS_NAME)
                .is_some_and(|name| name == "ferris-server")
        );
        assert!(
            headers
                .get(PRIVATE_NETWORK_ACCESS_ID)
                .is_some_and(|id| id == "01:23:45:67:89:0A")
        );
    }
}