#![deny(
const_err,
dead_code,
deprecated,
arithmetic_overflow,
improper_ctypes,
missing_docs,
mutable_transmutes,
no_mangle_const_items,
non_camel_case_types,
non_shorthand_field_patterns,
non_upper_case_globals,
overflowing_literals,
path_statements,
stable_features,
trivial_casts,
trivial_numeric_casts,
unconditional_recursion,
unknown_crate_types,
unreachable_code,
unused_allocation,
unused_assignments,
unused_attributes,
unused_comparisons,
unused_extern_crates,
unused_features,
unused_imports,
unused_import_braces,
unused_qualifications,
unused_must_use,
unused_mut,
unused_parens,
unused_results,
unused_unsafe,
variant_size_differences,
warnings,
while_true
)]
#![allow(
missing_copy_implementations,
missing_debug_implementations,
unknown_lints,
unsafe_code,
rustdoc::broken_intra_doc_links
)]
#![doc(test(attr(allow(unused_variables), deny(warnings))))]
#[cfg(test)]
#[macro_use]
mod test_macros;
mod fairing;
pub mod headers;
use std::borrow::Cow;
use std::collections::HashSet;
use std::error;
use std::fmt;
use std::marker::PhantomData;
use std::ops::Deref;
use std::str::FromStr;
#[allow(unused_imports)]
use ::log::{debug, error, info};
use regex::RegexSet;
use rocket::http::{self, Status};
use rocket::request::{FromRequest, Request};
use rocket::response;
use rocket::{debug_, error_, info_, outcome::Outcome, State};
#[cfg(feature = "serialization")]
use serde_derive::{Deserialize, Serialize};
use crate::headers::{
AccessControlRequestHeaders, AccessControlRequestMethod, HeaderFieldName, HeaderFieldNamesSet,
Origin,
};
/// Errors produced while building a CORS policy or validating a request against it.
///
/// `Error::status` maps each variant to an HTTP status: `403 Forbidden` for policy
/// violations, `500 Internal Server Error` for server-side problems, `400 Bad Request`
/// for the rest.
#[derive(Debug)]
pub enum Error {
/// The request header `Origin` is required but is missing.
MissingOrigin,
/// A URL failed to parse: either the request's `Origin` header or a configured exact origin.
BadOrigin(url::ParseError),
/// Configured exact origins that turned out to be opaque; they can only be matched by regex.
OpaqueAllowedOrigin(Vec<String>),
/// A preflight request did not carry `Access-Control-Request-Method`.
MissingRequestMethod,
/// The `Access-Control-Request-Method` header has an invalid value.
BadRequestMethod,
/// The `Access-Control-Request-Headers` header is required but is missing.
MissingRequestHeaders,
/// The request's origin (the contained string) is not in the allowed set.
OriginNotAllowed(String),
/// The requested method (the contained string) is not in the allowed set.
MethodNotAllowed(String),
/// A configured origin regex failed to compile.
RegexError(regex::Error),
/// The preflight asked for headers outside the allowed set.
HeadersNotAllowed,
/// Credentials are allowed while every origin is answered with `*`; rejected by `CorsOptions::validate`.
CredentialsWithWildcardOrigin,
/// A CORS request guard ran, but no `Cors` value was managed in Rocket's state.
MissingCorsInRocketState,
/// The fairing's `on_response` handler could not find the header injected into the request.
MissingInjectedHeader,
}
impl Error {
    /// Maps this error to the HTTP status reported to the client.
    ///
    /// * `403 Forbidden`: the request is well formed but violates the policy.
    /// * `500 Internal Server Error`: the server-side configuration or state is broken.
    /// * `400 Bad Request`: malformed headers or invalid configured origins/patterns.
    ///
    /// The match is exhaustive on purpose, so a new variant forces a decision here.
    fn status(&self) -> Status {
        match self {
            Self::MissingOrigin
            | Self::OriginNotAllowed(_)
            | Self::MethodNotAllowed(_)
            | Self::HeadersNotAllowed => Status::Forbidden,
            Self::CredentialsWithWildcardOrigin
            | Self::MissingCorsInRocketState
            | Self::MissingInjectedHeader => Status::InternalServerError,
            Self::BadOrigin(_)
            | Self::OpaqueAllowedOrigin(_)
            | Self::MissingRequestMethod
            | Self::BadRequestMethod
            | Self::MissingRequestHeaders
            | Self::RegexError(_) => Status::BadRequest,
        }
    }
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Error::MissingOrigin => write!(
f,
"The request header `Origin` is \
required but is missing"
),
Error::BadOrigin(_) => write!(f, "The request header `Origin` contains an invalid URL"),
Error::MissingRequestMethod => write!(
f,
"The request header `Access-Control-Request-Method` \
is required but is missing"
),
Error::BadRequestMethod => write!(
f,
"The request header `Access-Control-Request-Method` has an invalid value"
),
Error::MissingRequestHeaders => write!(
f,
"The request header `Access-Control-Request-Headers` \
is required but is missing"
),
Error::OriginNotAllowed(origin) => write!(
f,
"Origin '{}' is \
not allowed to request",
origin
),
Error::MethodNotAllowed(method) => write!(f, "Method '{}' is not allowed", &method),
Error::HeadersNotAllowed => write!(f, "Headers are not allowed"),
Error::CredentialsWithWildcardOrigin => write!(
f,
"Credentials are allowed, but the Origin is set to \"*\". \
This is not allowed by W3C"
),
Error::MissingCorsInRocketState => write!(
f,
"A CORS Request Guard was used, but no CORS Options \
was available in Rocket's state"
),
Error::MissingInjectedHeader => {
write!(f,
"The `on_response` handler of Fairing could not find the injected header from the \
Request. Either some other fairing has removed it, or this is a bug.")
}
Error::OpaqueAllowedOrigin(ref origins) => write!(
f,
"The configured Origins '{}' are Opaque Origins. \
Use regex instead.",
origins.join("; ")
),
Error::RegexError(ref e) => write!(f, "{}", e),
}
}
}
impl error::Error for Error {
fn cause(&self) -> Option<&dyn error::Error> {
match *self {
Error::BadOrigin(ref e) => Some(e),
_ => Some(self),
}
}
}
impl<'r, 'o: 'r> response::Responder<'r, 'o> for Error {
    /// Logs the error and hands its HTTP status back to Rocket.
    ///
    /// Returning `Err(status)` means Rocket's error catcher for that status
    /// produces the body, not this responder.
    fn respond_to(self, _: &Request<'_>) -> Result<response::Response<'o>, Status> {
        let status = self.status();
        error_!("CORS Error: {}", self);
        Err(status)
    }
}
impl From<url::ParseError> for Error {
    /// Wraps a URL parse failure as [`Error::BadOrigin`].
    fn from(source: url::ParseError) -> Self {
        Self::BadOrigin(source)
    }
}
impl From<regex::Error> for Error {
fn from(error: regex::Error) -> Self {
Error::RegexError(error)
}
}
/// Either "everything" or an explicit subset `T`.
///
/// Used for settings such as allowed origins and allowed headers, where `All`
/// means the setting places no restriction.
#[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))]
pub enum AllOrSome<T> {
    /// Everything is allowed.
    All,
    /// Only the contained subset is allowed.
    Some(T),
}

impl<T> Default for AllOrSome<T> {
    /// Defaults to [`AllOrSome::All`], i.e. no restriction.
    fn default() -> Self {
        Self::All
    }
}

impl<T> AllOrSome<T> {
    /// Returns `true` if this is [`AllOrSome::All`].
    pub fn is_all(&self) -> bool {
        matches!(self, Self::All)
    }

    /// Returns `true` if this is [`AllOrSome::Some`].
    pub fn is_some(&self) -> bool {
        matches!(self, Self::Some(_))
    }

    /// Consumes `self` and returns the contained subset.
    ///
    /// # Panics
    ///
    /// Panics if `self` is [`AllOrSome::All`].
    pub fn unwrap(self) -> T {
        if let Self::Some(inner) = self {
            inner
        } else {
            panic!("Attempting to unwrap an `All`")
        }
    }
}
/// An HTTP method, wrapping Rocket's [`http::Method`].
///
/// The newtype lets this crate implement extra traits on the method, such as serde
/// support behind the `serialization` feature. It dereferences to the wrapped
/// [`http::Method`].
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct Method(http::Method);

impl FromStr for Method {
    type Err = ();

    /// Parses a verb such as `"GET"` with Rocket's parser.
    ///
    /// Strings Rocket cannot parse yield `Err(())`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        s.parse::<http::Method>().map(Method)
    }
}

impl Deref for Method {
    type Target = http::Method;

    fn deref(&self) -> &Self::Target {
        let Method(inner) = self;
        inner
    }
}

impl From<http::Method> for Method {
    /// Wraps a Rocket method.
    fn from(inner: http::Method) -> Self {
        Self(inner)
    }
}

impl fmt::Display for Method {
    /// Displays exactly like the wrapped [`http::Method`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Method(inner) = self;
        fmt::Display::fmt(inner, f)
    }
}
#[cfg(feature = "serialization")]
mod method_serde {
    // serde support for `Method`: a method is (de)serialized as its verb string.
    use std::fmt;
    use std::str::FromStr;

    use serde::{self, Deserialize, Serialize};

    use crate::Method;

    impl Serialize for Method {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            let verb = self.as_str();
            serializer.serialize_str(verb)
        }
    }

    impl<'de> Deserialize<'de> for Method {
        fn deserialize<D>(deserializer: D) -> Result<Method, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            use serde::de::{self, Visitor};

            /// Accepts a string and parses it with `Method::from_str`.
            struct MethodVisitor;

            impl<'de> Visitor<'de> for MethodVisitor {
                type Value = Method;

                fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                    formatter.write_str("a string containing a HTTP Verb")
                }

                fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
                where
                    E: de::Error,
                {
                    // The parse error is `()`, so its Debug form becomes the message.
                    Method::from_str(s).map_err(|err| E::custom(format!("{:?}", err)))
                }
            }

            deserializer.deserialize_string(MethodVisitor)
        }
    }
}
pub type AllowedOrigins = AllOrSome<Origins>;
impl AllowedOrigins {
#[allow(clippy::needless_lifetimes)]
pub fn some<'a, 'b, S1: AsRef<str>, S2: AsRef<str>>(exact: &'a [S1], regex: &'b [S2]) -> Self {
AllOrSome::Some(Origins {
exact: Some(exact.iter().map(|s| s.as_ref().to_string()).collect()),
regex: Some(regex.iter().map(|s| s.as_ref().to_string()).collect()),
..Default::default()
})
}
pub fn some_exact<S: AsRef<str>>(exact: &[S]) -> Self {
AllOrSome::Some(Origins {
exact: Some(exact.iter().map(|s| s.as_ref().to_string()).collect()),
..Default::default()
})
}
pub fn some_regex<S: AsRef<str>>(regex: &[S]) -> Self {
AllOrSome::Some(Origins {
regex: Some(regex.iter().map(|s| s.as_ref().to_string()).collect()),
..Default::default()
})
}
pub fn some_null() -> Self {
AllOrSome::Some(Origins {
allow_null: true,
..Default::default()
})
}
pub fn all() -> Self {
AllOrSome::All
}
}
/// The origins allowed when [`AllowedOrigins`] is `Some`.
///
/// An origin is accepted if it is `null` and `allow_null` is set, equals one of the
/// `exact` origins, or matches one of the `regex` patterns.
#[derive(Clone, PartialEq, Eq, Debug, Default)]
#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serialization", serde(default))]
pub struct Origins {
/// Whether the `null` origin is accepted.
#[cfg_attr(feature = "serialization", serde(default))]
pub allow_null: bool,
/// Exact origin URLs. Each is parsed and reduced to its origin; opaque origins are
/// rejected with `Error::OpaqueAllowedOrigin` when the policy is built.
#[cfg_attr(feature = "serialization", serde(default))]
pub exact: Option<HashSet<String>>,
/// Regex patterns. They are matched against the ASCII serialization of parsed
/// origins, or against the raw value of opaque origins.
#[cfg_attr(feature = "serialization", serde(default))]
pub regex: Option<HashSet<String>>,
}
/// [`Origins`] after validation: exact origins parsed, regexes compiled.
#[derive(Clone, Debug)]
pub(crate) struct ParsedAllowedOrigins {
/// Whether the `null` origin is accepted.
pub allow_null: bool,
/// Exact (tuple) origins; opaque ones are rejected during parsing.
pub exact: HashSet<url::Origin>,
/// All configured patterns compiled into one set; `None` if no patterns were configured.
pub regex: Option<RegexSet>,
}
impl ParsedAllowedOrigins {
    /// Validates and pre-processes user-supplied [`Origins`].
    ///
    /// Every exact origin is parsed as a URL and reduced to its origin. If any of them
    /// is opaque (`is_tuple()` is false), all opaque entries are reported together, in
    /// their original spelling. Such origins can only be matched with a regex. The
    /// regex patterns are compiled into a single [`RegexSet`].
    ///
    /// # Errors
    ///
    /// * `Error::BadOrigin`: an exact origin is not a valid URL.
    /// * `Error::OpaqueAllowedOrigin`: one or more exact origins are opaque.
    /// * `Error::RegexError`: a pattern fails to compile.
    fn parse(origins: &Origins) -> Result<Self, Error> {
        let exact: Vec<(&str, url::Origin)> = match &origins.exact {
            Some(exact) => exact
                .iter()
                .map(|url| to_origin(url.as_str()).map(|origin| (url.as_str(), origin)))
                .collect::<Result<_, _>>()?,
            None => Vec::new(),
        };

        let (tuple, opaque): (Vec<_>, Vec<_>) =
            exact.into_iter().partition(|(_, url)| url.is_tuple());
        if !opaque.is_empty() {
            return Err(Error::OpaqueAllowedOrigin(
                opaque
                    .into_iter()
                    .map(|(original, _)| original.to_string())
                    .collect(),
            ));
        }
        let exact = tuple.into_iter().map(|(_, url)| url).collect();

        let regex = match &origins.regex {
            Some(patterns) => Some(RegexSet::new(patterns)?),
            None => None,
        };

        Ok(Self {
            allow_null: origins.allow_null,
            exact,
            regex,
        })
    }

    /// Returns whether `origin` is allowed.
    ///
    /// The `null` origin is allowed only if `allow_null` is set. A parsed origin is
    /// checked first against the exact set, then against the regex set using its
    /// ASCII serialization. An opaque origin is checked only against the regex set.
    ///
    /// # Panics
    ///
    /// Panics if an `Origin::Parsed` value is not a tuple origin. That is treated as
    /// a broken invariant of the `headers` module.
    fn verify(&self, origin: &Origin) -> bool {
        info_!("Verifying origin: {}", origin);
        match origin {
            Origin::Null => {
                info_!("Origin is null. Allowing? {}", self.allow_null);
                self.allow_null
            }
            Origin::Parsed(parsed) => {
                assert!(
                    parsed.is_tuple(),
                    "Parsed Origin is not tuple. This is a bug. Please report"
                );
                if self.exact.contains(parsed) {
                    info_!("Origin has an exact match");
                    return true;
                }
                self.verify_regex(&parsed.ascii_serialization())
            }
            Origin::Opaque(opaque) => self.verify_regex(opaque),
        }
    }

    /// Matches `origin` against the configured regex set.
    ///
    /// Returns `false` if no patterns were configured. Logging uses Rocket's
    /// indented `info_!`/`debug_!` macros, like the rest of the verification messages.
    fn verify_regex(&self, origin: &str) -> bool {
        match &self.regex {
            Some(regex_set) => {
                let regex_match = regex_set.is_match(origin);
                debug_!("Matching against regex set {:#?}", regex_set);
                info_!("Origin has a regex match? {}", regex_match);
                regex_match
            }
            None => {
                info_!("Origin does not match anything");
                false
            }
        }
    }
}
/// HTTP methods that CORS requests may use.
pub type AllowedMethods = HashSet<Method>;

/// Request headers a preflight may ask for: all headers, or an explicit set of names.
pub type AllowedHeaders = AllOrSome<HashSet<HeaderFieldName>>;

impl AllowedHeaders {
    /// Allows only the listed header names.
    pub fn some(headers: &[&str]) -> Self {
        let names: HashSet<HeaderFieldName> =
            headers.iter().map(|&name| name.to_string().into()).collect();
        AllOrSome::Some(names)
    }

    /// Allows every request header.
    pub fn all() -> Self {
        AllOrSome::All
    }
}
/// User-facing CORS configuration.
///
/// Turn it into a usable [`Cors`] with [`CorsOptions::to_cors`] or
/// [`Cors::from_options`]; both run [`CorsOptions::validate`] first. With the
/// `serialization` feature, any field may be omitted when deserializing and then
/// takes the same value as [`CorsOptions::default`].
#[derive(Eq, PartialEq, Clone, Debug)]
#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))]
pub struct CorsOptions {
/// Origins allowed to make CORS requests. Default: all origins.
#[cfg_attr(feature = "serialization", serde(default))]
pub allowed_origins: AllowedOrigins,
/// Methods a preflight may request (checked against `Access-Control-Request-Method`).
/// They are sent back in `Access-Control-Allow-Methods`.
/// Default: GET, HEAD, POST, OPTIONS, PUT, PATCH, DELETE.
#[cfg_attr(
feature = "serialization",
serde(default = "CorsOptions::default_allowed_methods")
)]
pub allowed_methods: AllowedMethods,
/// Headers a preflight may request via `Access-Control-Request-Headers`. Default: all.
#[cfg_attr(feature = "serialization", serde(default))]
pub allowed_headers: AllowedHeaders,
/// Whether to send `Access-Control-Allow-Credentials: true`. Default: `false`.
/// Rejected by `validate` together with all origins plus `send_wildcard`.
#[cfg_attr(feature = "serialization", serde(default))]
pub allow_credentials: bool,
/// Headers listed in `Access-Control-Expose-Headers` on non-preflight responses.
/// Default: empty.
#[cfg_attr(feature = "serialization", serde(default))]
pub expose_headers: HashSet<String>,
/// Value of `Access-Control-Max-Age` on preflight responses; `None` removes the header.
/// Default: `None`.
#[cfg_attr(feature = "serialization", serde(default))]
pub max_age: Option<usize>,
/// When all origins are allowed, answer with `*` instead of echoing the request's
/// origin. Echoing also adds `Vary: Origin`. Default: `false`.
#[cfg_attr(feature = "serialization", serde(default))]
pub send_wildcard: bool,
/// Base path for the fairing's route. Default: `/cors`.
/// NOTE(review): consumed by the `fairing` module, which is not visible here; confirm.
#[cfg_attr(
feature = "serialization",
serde(default = "CorsOptions::default_fairing_route_base")
)]
pub fairing_route_base: String,
/// Rank of the fairing's route. Default: `0`.
/// NOTE(review): consumed by the `fairing` module, which is not visible here; confirm.
#[cfg_attr(
feature = "serialization",
serde(default = "CorsOptions::default_fairing_route_rank")
)]
pub fairing_route_rank: isize,
}
impl Default for CorsOptions {
fn default() -> Self {
Self {
allowed_origins: Default::default(),
allowed_methods: Self::default_allowed_methods(),
allowed_headers: Default::default(),
allow_credentials: Default::default(),
expose_headers: Default::default(),
max_age: Default::default(),
send_wildcard: Default::default(),
fairing_route_base: Self::default_fairing_route_base(),
fairing_route_rank: Self::default_fairing_route_rank(),
}
}
}
impl CorsOptions {
fn default_allowed_methods() -> HashSet<Method> {
use rocket::http::Method;
vec![
Method::Get,
Method::Head,
Method::Post,
Method::Options,
Method::Put,
Method::Patch,
Method::Delete,
]
.into_iter()
.map(From::from)
.collect()
}
fn default_fairing_route_base() -> String {
"/cors".to_string()
}
fn default_fairing_route_rank() -> isize {
0
}
pub fn validate(&self) -> Result<(), Error> {
if self.allowed_origins.is_all() && self.send_wildcard && self.allow_credentials {
return Err(Error::CredentialsWithWildcardOrigin);
}
Ok(())
}
pub fn to_cors(&self) -> Result<Cors, Error> {
Cors::from_options(self)
}
pub fn allowed_origins(mut self, allowed_origins: AllowedOrigins) -> Self {
self.allowed_origins = allowed_origins;
self
}
pub fn allowed_methods(mut self, allowed_methods: AllowedMethods) -> Self {
self.allowed_methods = allowed_methods;
self
}
pub fn allowed_headers(mut self, allowed_headers: AllowedHeaders) -> Self {
self.allowed_headers = allowed_headers;
self
}
pub fn allow_credentials(mut self, allow_credentials: bool) -> Self {
self.allow_credentials = allow_credentials;
self
}
pub fn expose_headers(mut self, expose_headers: HashSet<String>) -> Self {
self.expose_headers = expose_headers;
self
}
pub fn max_age(mut self, max_age: Option<usize>) -> Self {
self.max_age = max_age;
self
}
pub fn send_wildcard(mut self, send_wildcard: bool) -> Self {
self.send_wildcard = send_wildcard;
self
}
pub fn fairing_route_base<S: Into<String>>(mut self, fairing_route_base: S) -> Self {
self.fairing_route_base = fairing_route_base.into();
self
}
pub fn fairing_route_rank(mut self, fairing_route_rank: isize) -> Self {
self.fairing_route_rank = fairing_route_rank;
self
}
}
/// A validated CORS policy, built from [`CorsOptions`] by [`Cors::from_options`].
///
/// The request guard reads it from Rocket's managed state (`State<Cors>`).
#[derive(Clone, Debug)]
pub struct Cors {
/// Allowed origins, with exact origins parsed and regexes compiled.
pub(crate) allowed_origins: AllOrSome<ParsedAllowedOrigins>,
/// See `CorsOptions::allowed_methods`.
pub(crate) allowed_methods: AllowedMethods,
/// See `CorsOptions::allowed_headers`.
pub(crate) allowed_headers: AllOrSome<HashSet<HeaderFieldName>>,
/// See `CorsOptions::allow_credentials`.
pub(crate) allow_credentials: bool,
/// See `CorsOptions::expose_headers`.
pub(crate) expose_headers: HashSet<String>,
/// See `CorsOptions::max_age`.
pub(crate) max_age: Option<usize>,
/// See `CorsOptions::send_wildcard`.
pub(crate) send_wildcard: bool,
/// See `CorsOptions::fairing_route_base`.
pub(crate) fairing_route_base: String,
/// See `CorsOptions::fairing_route_rank`.
pub(crate) fairing_route_rank: isize,
}
impl Cors {
pub fn from_options(options: &CorsOptions) -> Result<Self, Error> {
options.validate()?;
let allowed_origins = parse_allowed_origins(&options.allowed_origins)?;
Ok(Cors {
allowed_origins,
allowed_methods: options.allowed_methods.clone(),
allowed_headers: options.allowed_headers.clone(),
allow_credentials: options.allow_credentials,
expose_headers: options.expose_headers.clone(),
max_age: options.max_age,
send_wildcard: options.send_wildcard,
fairing_route_base: options.fairing_route_base.clone(),
fairing_route_rank: options.fairing_route_rank,
})
}
pub fn respond_owned<'r, 'o: 'r, F, R>(
self,
handler: F,
) -> Result<ManualResponder<'r, F, R>, Error>
where
F: FnOnce(Guard<'r>) -> R + 'r,
R: response::Responder<'r, 'o>,
{
Ok(ManualResponder::new(Cow::Owned(self), handler))
}
pub fn respond_borrowed<'r, 'o: 'r, F, R>(
&'r self,
handler: F,
) -> Result<ManualResponder<'r, F, R>, Error>
where
F: FnOnce(Guard<'r>) -> R + 'r,
R: response::Responder<'r, 'o>,
{
Ok(ManualResponder::new(Cow::Borrowed(self), handler))
}
}
/// The CORS headers computed for one request. `merge` writes them onto an outgoing response.
#[derive(Eq, PartialEq, Debug)]
pub(crate) struct Response {
/// Value for `Access-Control-Allow-Origin`: `All` is written as `*`.
/// `None` makes `merge` leave the response untouched.
allow_origin: Option<AllOrSome<String>>,
/// Methods for `Access-Control-Allow-Methods`; empty removes the header.
allow_methods: HashSet<Method>,
/// Names for `Access-Control-Allow-Headers`; empty removes the header.
allow_headers: HeaderFieldNamesSet,
/// `true` writes `Access-Control-Allow-Credentials: true`; `false` removes the header.
allow_credentials: bool,
/// Names for `Access-Control-Expose-Headers`; empty removes the header.
expose_headers: HeaderFieldNamesSet,
/// Value for `Access-Control-Max-Age`; `None` removes the header.
max_age: Option<usize>,
/// Whether to append `Vary: Origin`, because the origin was echoed back.
vary_origin: bool,
}
impl Response {
fn new() -> Self {
Self {
allow_origin: None,
allow_headers: HashSet::new(),
allow_methods: HashSet::new(),
allow_credentials: false,
expose_headers: HashSet::new(),
max_age: None,
vary_origin: false,
}
}
fn origin(mut self, origin: &str, vary_origin: bool) -> Self {
self.allow_origin = Some(AllOrSome::Some(origin.to_string()));
self.vary_origin = vary_origin;
self
}
fn any(mut self) -> Self {
self.allow_origin = Some(AllOrSome::All);
self
}
fn credentials(mut self, value: bool) -> Self {
self.allow_credentials = value;
self
}
fn exposed_headers(mut self, headers: &[&str]) -> Self {
self.expose_headers = headers.iter().map(|s| (*s).to_string().into()).collect();
self
}
fn max_age(mut self, value: Option<usize>) -> Self {
self.max_age = value;
self
}
fn methods(mut self, methods: &HashSet<Method>) -> Self {
self.allow_methods = methods.clone();
self
}
fn headers(mut self, headers: &[&str]) -> Self {
self.allow_headers = headers.iter().map(|s| (*s).to_string().into()).collect();
self
}
pub fn responder<'r, 'o: 'r, R: response::Responder<'r, 'o>>(
self,
responder: R,
) -> Responder<R> {
Responder::new(responder, self)
}
pub fn response<'r>(&self, base: response::Response<'r>) -> response::Response<'r> {
let mut response = response::Response::build_from(base).finalize();
self.merge(&mut response);
response
}
fn merge(&self, response: &mut response::Response<'_>) {
let origin = match self.allow_origin {
None => {
return;
}
Some(ref origin) => origin,
};
let origin = match *origin {
AllOrSome::All => "*".to_string(),
AllOrSome::Some(ref origin) => origin.to_string(),
};
let _ = response.set_raw_header("Access-Control-Allow-Origin", origin);
if self.allow_credentials {
let _ = response.set_raw_header("Access-Control-Allow-Credentials", "true");
} else {
response.remove_header("Access-Control-Allow-Credentials");
}
if !self.expose_headers.is_empty() {
let headers: Vec<String> = self
.expose_headers
.iter()
.map(|s| s.deref().to_string())
.collect();
let headers = headers.join(", ");
let _ = response.set_raw_header("Access-Control-Expose-Headers", headers);
} else {
response.remove_header("Access-Control-Expose-Headers");
}
if !self.allow_headers.is_empty() {
let headers: Vec<String> = self
.allow_headers
.iter()
.map(|s| s.deref().to_string())
.collect();
let headers = headers.join(", ");
let _ = response.set_raw_header("Access-Control-Allow-Headers", headers);
} else {
response.remove_header("Access-Control-Allow-Headers");
}
if !self.allow_methods.is_empty() {
let methods: Vec<_> = self.allow_methods.iter().map(|m| m.as_str()).collect();
let methods = methods.join(", ");
let _ = response.set_raw_header("Access-Control-Allow-Methods", methods);
} else {
response.remove_header("Access-Control-Allow-Methods");
}
if self.max_age.is_some() {
let max_age = self.max_age.unwrap();
let _ = response.set_raw_header("Access-Control-Max-Age", max_age.to_string());
} else {
response.remove_header("Access-Control-Max-Age");
}
if self.vary_origin {
response.adjoin_raw_header("Vary", "Origin");
}
}
pub fn validate_and_build<'a, 'r>(
options: &'a Cors,
request: &'a Request<'r>,
) -> Result<Self, Error> {
validate_and_build(options, request)
}
}
/// Request guard that validates the incoming request against the managed [`Cors`] policy.
///
/// It carries the resulting CORS headers, which [`Guard::responder`] or
/// [`Guard::response`] merge into the route's response.
pub struct Guard<'r> {
    response: Response,
    marker: PhantomData<&'r Response>,
}

impl<'r, 'o: 'r> Guard<'r> {
    /// Wraps already-computed CORS headers.
    fn new(response: Response) -> Self {
        let marker = PhantomData;
        Guard { response, marker }
    }

    /// Wraps `responder` so the CORS headers are merged into whatever it produces.
    pub fn responder<R: response::Responder<'r, 'o>>(self, responder: R) -> Responder<R> {
        let Guard { response, .. } = self;
        response.responder(responder)
    }

    /// Returns `base` with the CORS headers merged in.
    pub fn response(&self, base: response::Response<'r>) -> response::Response<'r> {
        let headers = &self.response;
        headers.response(base)
    }
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for Guard<'r> {
type Error = Error;
async fn from_request(request: &'r Request<'_>) -> rocket::request::Outcome<Self, Self::Error> {
let options = match request.guard::<&State<Cors>>().await {
Outcome::Success(options) => options,
_ => {
let error = Error::MissingCorsInRocketState;
return Outcome::Failure((error.status(), error));
}
};
match Response::validate_and_build(options, request) {
Ok(response) => Outcome::Success(Self::new(response)),
Err(error) => Outcome::Failure((error.status(), error)),
}
}
}
/// A responder that wraps another responder `R`.
///
/// When responding, it merges the stored CORS headers into whatever `R` produces.
#[derive(Debug)]
pub struct Responder<R> {
    responder: R,
    cors_response: Response,
}

impl<'r, 'o: 'r, R: response::Responder<'r, 'o>> Responder<R> {
    /// Pairs `responder` with the CORS headers to merge into its output.
    fn new(responder: R, cors_response: Response) -> Self {
        Responder {
            responder,
            cors_response,
        }
    }

    /// Runs the inner responder, then merges the CORS headers into its response.
    ///
    /// Errors from the inner responder are propagated untouched.
    fn respond(self, request: &'r Request<'_>) -> response::Result<'o> {
        let Self {
            responder,
            cors_response,
        } = self;
        let mut response = responder.respond_to(request)?;
        cors_response.merge(&mut response);
        Ok(response)
    }
}

impl<'r, 'o: 'r, R: response::Responder<'r, 'o>> response::Responder<'r, 'o> for Responder<R> {
    fn respond_to(self, request: &'r Request<'_>) -> response::Result<'o> {
        self.respond(request)
    }
}
/// A responder produced by [`Cors::respond_owned`] / [`Cors::respond_borrowed`].
///
/// At response time it validates the request against a borrowed or owned [`Cors`],
/// passes the resulting [`Guard`] to `handler`, and responds with what the handler
/// returns.
pub struct ManualResponder<'r, F, R> {
    options: Cow<'r, Cors>,
    handler: F,
    marker: PhantomData<R>,
}

impl<'r, 'o: 'r, F, R> ManualResponder<'r, F, R>
where
    F: FnOnce(Guard<'r>) -> R + 'r,
    R: response::Responder<'r, 'o>,
{
    /// Stores the policy and the handler until the response is produced.
    fn new(options: Cow<'r, Cors>, handler: F) -> Self {
        Self {
            options,
            handler,
            marker: PhantomData,
        }
    }

    /// Validates `request` against the stored policy and wraps the result in a [`Guard`].
    fn build_guard(&self, request: &Request<'_>) -> Result<Guard<'r>, Error> {
        let policy: &Cors = &self.options;
        Response::validate_and_build(policy, request).map(Guard::new)
    }
}

impl<'r, 'o: 'r, F, R> response::Responder<'r, 'o> for ManualResponder<'r, F, R>
where
    F: FnOnce(Guard<'r>) -> R + 'r,
    R: response::Responder<'r, 'o>,
{
    /// Runs the handler with a freshly built guard.
    ///
    /// On a CORS validation failure it logs the error and returns its status instead.
    fn respond_to(self, request: &'r Request<'_>) -> response::Result<'o> {
        match self.build_guard(request) {
            Ok(guard) => (self.handler)(guard).respond_to(request),
            Err(err) => {
                error_!("CORS error: {}", err);
                Err(err.status())
            }
        }
    }
}
/// Outcome of validating a request, as produced by `validate`.
#[derive(Debug, Eq, PartialEq)]
#[allow(variant_size_differences)]
enum ValidationResult {
/// The request has no `Origin` header; no CORS headers are produced.
None,
/// A validated `OPTIONS` preflight with its origin and optional requested headers.
Preflight {
origin: String,
headers: Option<AccessControlRequestHeaders>,
},
/// A validated non-`OPTIONS` request with its origin.
Request { origin: String },
}
fn to_origin<S: AsRef<str>>(origin: S) -> Result<url::Origin, Error> {
Ok(url::Url::parse(origin.as_ref())?.origin())
}
/// Parses the user-facing origin configuration into its validated form.
///
/// `All` passes through unchanged.
fn parse_allowed_origins(
    origins: &AllowedOrigins,
) -> Result<AllOrSome<ParsedAllowedOrigins>, Error> {
    match origins {
        AllOrSome::Some(configured) => ParsedAllowedOrigins::parse(configured).map(AllOrSome::Some),
        AllOrSome::All => Ok(AllOrSome::All),
    }
}
/// Validates `request` against `options` and builds the CORS headers to send back.
///
/// Non-CORS requests (no `Origin`) get an empty response that adds no headers.
fn validate_and_build(options: &Cors, request: &Request<'_>) -> Result<Response, Error> {
    let response = match validate(options, request)? {
        ValidationResult::Preflight { origin, headers } => {
            preflight_response(options, &origin, headers.as_ref())
        }
        ValidationResult::Request { origin } => actual_request_response(options, &origin),
        ValidationResult::None => Response::new(),
    };
    Ok(response)
}
/// Classifies `request` and checks it against the policy.
///
/// * No `Origin` header: this is not a CORS request, so `ValidationResult::None`.
/// * `OPTIONS`: validated as a preflight.
/// * Any other method: validated as an actual CORS request.
fn validate(options: &Cors, request: &Request<'_>) -> Result<ValidationResult, Error> {
    let origin = match origin(request)? {
        Some(origin) => origin,
        None => return Ok(ValidationResult::None),
    };

    if matches!(request.method(), http::Method::Options) {
        let method = request_method(request)?;
        let headers = request_headers(request)?;
        preflight_validate(options, &origin, &method, &headers)?;
        Ok(ValidationResult::Preflight {
            origin: origin.to_string(),
            headers,
        })
    } else {
        actual_request_validate(options, &origin)?;
        Ok(ValidationResult::Request {
            origin: origin.to_string(),
        })
    }
}
/// Checks `origin` against the allowed origins.
///
/// # Errors
///
/// Returns `Error::OriginNotAllowed` carrying the origin's string form.
fn validate_origin(
    origin: &Origin,
    allowed_origins: &AllOrSome<ParsedAllowedOrigins>,
) -> Result<(), Error> {
    let allowed = match allowed_origins {
        AllOrSome::All => true,
        AllOrSome::Some(parsed) => parsed.verify(origin),
    };
    if allowed {
        Ok(())
    } else {
        Err(Error::OriginNotAllowed(origin.to_string()))
    }
}
fn validate_allowed_method(
method: &AccessControlRequestMethod,
allowed_methods: &AllowedMethods,
) -> Result<(), Error> {
let &AccessControlRequestMethod(ref request_method) = method;
if !allowed_methods.iter().any(|m| m == request_method) {
return Err(Error::MethodNotAllowed(method.0.to_string()));
}
Ok(())
}
fn validate_allowed_headers(
headers: &AccessControlRequestHeaders,
allowed_headers: &AllowedHeaders,
) -> Result<(), Error> {
let &AccessControlRequestHeaders(ref headers) = headers;
match *allowed_headers {
AllOrSome::All => Ok(()),
AllOrSome::Some(ref allowed_headers) => {
if !headers.is_empty() && !headers.is_subset(allowed_headers) {
return Err(Error::HeadersNotAllowed);
}
Ok(())
}
}
}
/// Reads the request's `Origin` header.
///
/// Returns `Ok(None)` when the header guard forwards (no header present).
fn origin(request: &Request<'_>) -> Result<Option<Origin>, Error> {
    let outcome = Origin::from_request_sync(request);
    match outcome {
        Outcome::Success(found) => Ok(Some(found)),
        Outcome::Failure((_, err)) => Err(err),
        Outcome::Forward(()) => Ok(None),
    }
}
/// Reads the `Access-Control-Request-Method` header.
///
/// Returns `Ok(None)` when the guard forwards.
fn request_method(request: &Request<'_>) -> Result<Option<AccessControlRequestMethod>, Error> {
    let outcome = AccessControlRequestMethod::from_request_sync(request);
    match outcome {
        Outcome::Success(found) => Ok(Some(found)),
        Outcome::Failure((_, err)) => Err(err),
        Outcome::Forward(()) => Ok(None),
    }
}
/// Reads the `Access-Control-Request-Headers` header.
///
/// Returns `Ok(None)` when the guard forwards.
fn request_headers(request: &Request<'_>) -> Result<Option<AccessControlRequestHeaders>, Error> {
    let outcome = AccessControlRequestHeaders::from_request_sync(request);
    match outcome {
        Outcome::Success(headers) => Ok(Some(headers)),
        Outcome::Failure((_, err)) => Err(err),
        Outcome::Forward(()) => Ok(None),
    }
}
fn preflight_validate(
options: &Cors,
origin: &Origin,
method: &Option<AccessControlRequestMethod>,
headers: &Option<AccessControlRequestHeaders>,
) -> Result<(), Error> {
validate_origin(origin, &options.allowed_origins)?;
let method = method.as_ref().ok_or(Error::MissingRequestMethod)?;
validate_allowed_method(method, &options.allowed_methods)?;
if let Some(ref headers) = *headers {
validate_allowed_headers(headers, &options.allowed_headers)?;
}
Ok(())
}
/// Builds the CORS headers for a validated preflight.
///
/// When every origin is allowed and `send_wildcard` is off, the request's origin is
/// echoed back and `Vary: Origin` is requested. The requested headers, if any, are
/// echoed in `Access-Control-Allow-Headers`.
fn preflight_response(
    options: &Cors,
    origin: &str,
    headers: Option<&AccessControlRequestHeaders>,
) -> Response {
    let base = match options.allowed_origins {
        AllOrSome::All if options.send_wildcard => Response::new().any(),
        AllOrSome::All => Response::new().origin(origin, true),
        AllOrSome::Some(_) => Response::new().origin(origin, false),
    };
    let response = base
        .credentials(options.allow_credentials)
        .max_age(options.max_age)
        .methods(&options.allowed_methods);

    match headers {
        Some(AccessControlRequestHeaders(requested)) => {
            let names: Vec<&str> = requested.iter().map(|name| &**name.deref()).collect();
            response.headers(&names)
        }
        None => response,
    }
}
/// Validates a non-preflight CORS request; only the origin is checked.
fn actual_request_validate(options: &Cors, origin: &Origin) -> Result<(), Error> {
    validate_origin(origin, &options.allowed_origins)
}
/// Builds the CORS headers for a validated non-preflight request.
///
/// The origin is handled as in `preflight_response`. Credentials and the configured
/// exposed headers are added; no methods, allowed headers or max-age are sent.
fn actual_request_response(options: &Cors, origin: &str) -> Response {
    let base = match options.allowed_origins {
        AllOrSome::All if options.send_wildcard => Response::new().any(),
        AllOrSome::All => Response::new().origin(origin, true),
        AllOrSome::Some(_) => Response::new().origin(origin, false),
    };
    let exposed: Vec<&str> = options.expose_headers.iter().map(String::as_str).collect();
    base.credentials(options.allow_credentials)
        .exposed_headers(&exposed)
}
/// Returns a single `OPTIONS` route that matches every path.
///
/// The route answers CORS preflights using the managed [`Cors`] policy. It is
/// ranked `isize::MAX`; Rocket tries lower ranks first, so any explicitly mounted
/// `OPTIONS` route for the same path should take precedence.
pub fn catch_all_options_routes() -> Vec<rocket::Route> {
    let route = rocket::Route::ranked(
        isize::MAX,
        http::Method::Options,
        "/<catch_all_options_route..>",
        CatchAllOptionsRouteHandler {},
    );
    vec![route]
}
/// Handler behind [`catch_all_options_routes`].
///
/// It validates the preflight through the [`Guard`] request guard and answers
/// with the unit responder `()` plus the CORS headers.
#[derive(Clone)]
struct CatchAllOptionsRouteHandler {}

#[rocket::async_trait]
impl rocket::route::Handler for CatchAllOptionsRouteHandler {
    // The request body is irrelevant to a preflight, so the data parameter is ignored.
    // The previous body contained `let _ = &__arg2;`, an identifier that only exists
    // inside async_trait's macro expansion; it has been removed.
    async fn handle<'r>(
        &self,
        request: &'r Request<'_>,
        _: rocket::Data<'r>,
    ) -> rocket::route::Outcome<'r> {
        let guard: Guard<'_> = match request.guard().await {
            Outcome::Success(guard) => guard,
            Outcome::Failure((status, _)) => return rocket::route::Outcome::failure(status),
            // `Guard`'s `FromRequest` impl only ever succeeds or fails.
            Outcome::Forward(()) => unreachable!("Should not be reachable"),
        };
        info_!(
            "\"Catch all\" handling of CORS `OPTIONS` preflight for request {}",
            request
        );
        rocket::route::Outcome::from(request, guard.responder(()))
    }
}
#[cfg(test)]
mod tests {
use std::str::FromStr;
use rocket::http::hyper;
use rocket::http::Header;
use rocket::local::blocking::Client;
use super::*;
use crate::http::Method;
static ORIGIN: hyper::HeaderName = hyper::header::ORIGIN;
static ACCESS_CONTROL_REQUEST_METHOD: hyper::HeaderName =
hyper::header::ACCESS_CONTROL_REQUEST_METHOD;
static ACCESS_CONTROL_REQUEST_HEADERS: hyper::HeaderName =
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS;
fn to_parsed_origin<S: AsRef<str>>(origin: S) -> Result<Origin, Error> {
Origin::from_str(origin.as_ref())
}
fn make_cors_options() -> CorsOptions {
let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
CorsOptions {
allowed_origins,
allowed_methods: vec![http::Method::Get]
.into_iter()
.map(From::from)
.collect(),
allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
allow_credentials: true,
expose_headers: ["Content-Type", "X-Custom"]
.iter()
.map(|s| (*s).to_string())
.collect(),
..Default::default()
}
}
fn make_invalid_options() -> CorsOptions {
let mut cors = make_cors_options();
cors.allow_credentials = true;
cors.allowed_origins = AllOrSome::All;
cors.send_wildcard = true;
cors
}
fn make_client() -> Client {
let rocket = rocket::build();
Client::tracked(rocket).expect("valid rocket instance")
}
#[test]
fn cors_is_validated() {
assert!(make_cors_options().validate().is_ok())
}
#[test]
#[should_panic(expected = "CredentialsWithWildcardOrigin")]
fn cors_validates_illegal_allow_credentials() {
let cors = make_invalid_options();
cors.validate().unwrap();
}
#[test]
fn cors_options_from_builder_pattern() {
let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
let cors_options_from_builder = CorsOptions::default()
.allowed_origins(allowed_origins)
.allowed_methods(
vec![http::Method::Get]
.into_iter()
.map(From::from)
.collect(),
)
.allowed_headers(AllowedHeaders::some(&["Authorization", "Accept"]))
.allow_credentials(true)
.expose_headers(
["Content-Type", "X-Custom"]
.iter()
.map(|s| (*s).to_string())
.collect(),
);
assert_eq!(cors_options_from_builder, make_cors_options());
}
#[cfg(feature = "serialization")]
#[test]
fn cors_default_deserialization_is_correct() {
let deserialized: CorsOptions = serde_json::from_str("{}").expect("To not fail");
assert_eq!(deserialized, CorsOptions::default());
let expected_json = r#"
{
"allowed_origins": "All",
"allowed_methods": [
"POST",
"PATCH",
"PUT",
"DELETE",
"HEAD",
"OPTIONS",
"GET"
],
"allowed_headers": "All",
"allow_credentials": false,
"expose_headers": [],
"max_age": null,
"send_wildcard": false,
"fairing_route_base": "/cors",
"fairing_route_rank": 0
}
"#;
let actual: CorsOptions = serde_json::from_str(expected_json).expect("to not fail");
assert_eq!(actual, CorsOptions::default());
}
#[cfg(feature = "serialization")]
#[test]
fn cors_options_example_can_be_deserialized() {
let json = r#"{
"allowed_origins": {
"Some": {
"exact": ["https://www.acme.com"],
"regex": ["^https://www.example-[A-z0-9]*.com$"]
}
},
"allowed_methods": [
"POST",
"DELETE",
"GET"
],
"allowed_headers": {
"Some": [
"Accept",
"Authorization"
]
},
"allow_credentials": true,
"expose_headers": [
"Content-Type",
"X-Custom"
],
"max_age": 42,
"send_wildcard": false,
"fairing_route_base": "/mycors"
}"#;
let _: CorsOptions = serde_json::from_str(json).expect("to not fail");
}
#[test]
fn allowed_some_origins_allows_different_lifetimes() {
let static_exact = ["http://www.example.com"];
let random_allocation = vec![1, 2, 3];
let port: *const Vec<i32> = &random_allocation;
let port = port as u16;
let random_regex = vec![format!("https://(.+):{}", port)];
let _ = AllowedOrigins::some(&static_exact, &random_regex);
}
#[test]
fn allowed_origins_are_parsed_correctly() {
let allowed_origins = not_err!(parse_allowed_origins(&AllowedOrigins::some(
&["https://www.acme.com"],
&["^https://www.example-[A-z0-9]+.com$"]
)));
assert!(allowed_origins.is_some());
let expected_exact: HashSet<url::Origin> = [url::Url::from_str("https://www.acme.com")
.expect("not to fail")
.origin()]
.iter()
.map(Clone::clone)
.collect();
let expected_regex = ["^https://www.example-[A-z0-9]+.com$"];
let actual = allowed_origins.unwrap();
assert_eq!(expected_exact, actual.exact);
assert_eq!(expected_regex, actual.regex.expect("to be some").patterns());
}
#[test]
fn allowed_origins_errors_on_opaque_exact() {
let error = parse_allowed_origins(&AllowedOrigins::some::<_, &str>(
&[
"chrome-extension://something",
"moz-extension://something",
"https://valid.com",
],
&[],
))
.unwrap_err();
match error {
Error::OpaqueAllowedOrigin(mut origins) => {
origins.sort();
assert_eq!(
origins,
["chrome-extension://something", "moz-extension://something"]
);
}
others => {
panic!("Unexpected error: {:#?}", others);
}
};
}
#[test]
fn validate_origin_allows_all_origins() {
let url = "https://www.example.com";
let origin = not_err!(to_parsed_origin(&url));
let allowed_origins = AllOrSome::All;
not_err!(validate_origin(&origin, &allowed_origins));
}
/// An origin listed verbatim in the exact allow-list is accepted.
#[test]
fn validate_origin_allows_origin() {
    let exact = AllowedOrigins::some_exact(&["https://www.example.com"]);
    let allowed_origins = not_err!(parse_allowed_origins(&exact));
    let request_origin = "https://www.example.com";
    let origin = not_err!(to_parsed_origin(&request_origin));
    not_err!(validate_origin(&origin, &allowed_origins));
}
/// A Unicode hostname and its punycode spelling are treated as the same origin, whichever
/// form appears in the request and whichever form is configured.
#[test]
fn validate_origin_handles_punycode_properly() {
    // Each pair is (request origin, configured exact origin).
    let cases = [
        ("https://аpple.com", "https://аpple.com"),
        ("https://аpple.com", "https://xn--pple-43d.com"),
        ("https://xn--pple-43d.com", "https://аpple.com"),
        ("https://xn--pple-43d.com", "https://xn--pple-43d.com"),
    ];
    for &(request_origin, configured) in &cases {
        let exact = AllowedOrigins::some_exact(&[configured]);
        let allowed_origins = not_err!(parse_allowed_origins(&exact));
        let origin = not_err!(to_parsed_origin(&request_origin));
        not_err!(validate_origin(&origin, &allowed_origins));
    }
}
/// Origins matching any configured regex are accepted.
#[test]
fn validate_origin_validates_regex() {
    // Dots are escaped so they match only a literal '.', and the class is spelled
    // `[A-Za-z0-9]`: `[A-z]` would also admit `[ \ ] ^ _` and backtick.
    let allowed_origins = not_err!(parse_allowed_origins(&AllowedOrigins::some_regex(&[
        r"^https://www\.example-[A-Za-z0-9]+\.com$",
        r"^https://(.+)\.acme\.com$",
    ])));
    let url = "https://www.example-something.com";
    let origin = not_err!(to_parsed_origin(&url));
    not_err!(validate_origin(&origin, &allowed_origins));
    let url = "https://subdomain.acme.com";
    let origin = not_err!(to_parsed_origin(&url));
    not_err!(validate_origin(&origin, &allowed_origins));
}
/// An opaque origin (a `moz-extension://` URL) can be allowed through a regex, even though
/// listing such origins as exact entries is rejected.
#[test]
fn validate_origin_validates_opaque_origins() {
    let allowed_origins = not_err!(parse_allowed_origins(&AllowedOrigins::some_regex(&[
        "moz-extension://.*"
    ])));
    let origin = not_err!(to_parsed_origin(
        &"moz-extension://8c7c4444-e29f-…cb8-1ade813dbd12/js/content.js:505"
    ));
    not_err!(validate_origin(&origin, &allowed_origins));
}
/// With both exact and regex entries configured, an origin matching either kind is accepted.
#[test]
fn validate_origin_validates_mixed_settings() {
    // Dots are escaped and the class is spelled `[A-Za-z0-9]`: `[A-z]` would also admit
    // `[ \ ] ^ _` and backtick.
    let allowed_origins = not_err!(parse_allowed_origins(&AllowedOrigins::some(
        &["https://www.acme.com"],
        &[r"^https://www\.example-[A-Za-z0-9]+\.com$"]
    )));
    let url = "https://www.example-something123.com";
    let origin = not_err!(to_parsed_origin(&url));
    not_err!(validate_origin(&origin, &allowed_origins));
    let url = "https://www.acme.com";
    let origin = not_err!(to_parsed_origin(&url));
    not_err!(validate_origin(&origin, &allowed_origins));
}
/// An origin missing from the exact allow-list yields `OriginNotAllowed`.
#[test]
#[should_panic(expected = "OriginNotAllowed")]
fn validate_origin_rejects_invalid_origin() {
    let allowed_origins = not_err!(parse_allowed_origins(&AllowedOrigins::some_exact(&[
        "https://www.example.com"
    ])));
    let request_origin = "https://www.acme.com";
    let origin = not_err!(to_parsed_origin(&request_origin));
    validate_origin(&origin, &allowed_origins).unwrap();
}
/// Echoing an origin with `vary_origin == false` sets the allow-origin header but no `Vary`.
#[test]
fn response_sets_allow_origin_without_vary_correctly() {
    let cors_response = Response::new().origin("https://www.example.com", false);
    let built = cors_response.response(response::Response::new());
    let allow_origin: Vec<_> = built.headers().get("Access-Control-Allow-Origin").collect();
    assert_eq!(vec!["https://www.example.com"], allow_origin);
    assert!(built.headers().get("Vary").next().is_none());
}
/// Echoing an origin with `vary_origin == true` sets the allow-origin header and must also
/// emit `Vary: Origin`.
#[test]
fn response_sets_allow_origin_with_vary_correctly() {
    let response = Response::new();
    let response = response.origin("https://www.example.com", true);
    let expected_header = vec!["https://www.example.com"];
    let response = response.response(response::Response::new());
    let actual_header: Vec<_> = response
        .headers()
        .get("Access-Control-Allow-Origin")
        .collect();
    assert_eq!(expected_header, actual_header);
    // An echoed (non-wildcard) origin makes the response depend on the request's `Origin`
    // header, so caches must be told via `Vary: Origin`. `Vary` values may be comma-separated
    // lists, so match individual tokens case-insensitively.
    let varies_on_origin = response
        .headers()
        .get("Vary")
        .flat_map(|value| value.split(','))
        .any(|token| token.trim().eq_ignore_ascii_case("Origin"));
    assert!(
        varies_on_origin,
        "expected `Vary: Origin` when vary_origin is true"
    );
}
/// `any()` makes the response carry the `*` wildcard as its allowed origin.
#[test]
fn response_sets_any_origin_correctly() {
    let cors_response = Response::new().any();
    let built = cors_response.response(response::Response::new());
    let allow_origin: Vec<_> = built.headers().get("Access-Control-Allow-Origin").collect();
    assert_eq!(vec!["*"], allow_origin);
}
/// Exposed header names are emitted as a single comma-separated
/// `Access-Control-Expose-Headers` value.
#[test]
fn response_sets_exposed_headers_correctly() {
    let headers = vec!["Bar", "Baz", "Foo"];
    let cors_response = Response::new()
        .origin("https://www.example.com", false)
        .exposed_headers(&headers);
    let built = cors_response.response(response::Response::new());
    let values: Vec<_> = built
        .headers()
        .get("Access-Control-Expose-Headers")
        .collect();
    assert_eq!(1, values.len());
    // The test does not rely on the order of names inside the value: split, trim, and sort.
    let mut exposed: Vec<String> = values[0]
        .split(',')
        .map(str::trim)
        .map(String::from)
        .collect();
    exposed.sort();
    assert_eq!(headers, exposed);
}
/// `max_age(Some(n))` emits `Access-Control-Max-Age: n`.
#[test]
fn response_sets_max_age_correctly() {
    let cors_response = Response::new()
        .origin("https://www.example.com", false)
        .max_age(Some(42));
    let built = cors_response.response(response::Response::new());
    let max_age: Vec<_> = built.headers().get("Access-Control-Max-Age").collect();
    assert_eq!(vec!["42"], max_age);
}
/// `max_age(None)` leaves `Access-Control-Max-Age` out of the response entirely.
#[test]
fn response_does_not_set_max_age_when_none() {
    let cors_response = Response::new()
        .origin("https://www.example.com", false)
        .max_age(None);
    let built = cors_response.response(response::Response::new());
    let mut max_age_values = built.headers().get("Access-Control-Max-Age");
    assert!(max_age_values.next().is_none());
}
/// A requested method that is in the allowed set passes validation.
#[test]
fn allowed_methods_validated_correctly() {
    let allowed_methods = vec![Method::Get, Method::Head, Method::Post]
        .into_iter()
        .map(From::from)
        .collect();
    let requested = "GET".parse().expect("not to fail");
    not_err!(validate_allowed_method(&requested, &allowed_methods));
}
/// A requested method outside the allowed set yields `MethodNotAllowed`.
#[test]
#[should_panic(expected = "MethodNotAllowed")]
fn allowed_methods_errors_on_disallowed_method() {
    let allowed_methods = vec![Method::Get, Method::Head, Method::Post]
        .into_iter()
        .map(From::from)
        .collect();
    let requested = "DELETE".parse().expect("not to fail");
    validate_allowed_method(&requested, &allowed_methods).unwrap()
}
/// With `AllOrSome::All`, any requested header list is accepted.
#[test]
fn all_allowed_headers_are_validated_correctly() {
    let requested = ["Bar", "Foo"].join(",");
    not_err!(validate_allowed_headers(
        &requested.parse().unwrap(),
        &AllOrSome::All,
    ));
}
/// Requested headers that form a subset of the allowed headers pass validation.
#[test]
fn allowed_headers_are_validated_correctly() {
    let requested = ["Bar", "Foo"].join(",");
    not_err!(validate_allowed_headers(
        &requested.parse().unwrap(),
        &AllOrSome::Some(
            ["Bar", "Baz", "Foo"]
                .iter()
                .map(|name| name.parse().unwrap())
                .collect(),
        ),
    ));
}
/// Requesting a header outside the allowed set ("Unknown") yields `HeadersNotAllowed`.
#[test]
#[should_panic(expected = "HeadersNotAllowed")]
fn allowed_headers_errors_on_non_subset() {
    let requested = ["Bar", "Foo", "Unknown"].join(",");
    validate_allowed_headers(
        &requested.parse().unwrap(),
        &AllOrSome::Some(
            ["Bar", "Baz", "Foo"]
                .iter()
                .map(|name| name.parse().unwrap())
                .collect(),
        ),
    )
    .unwrap();
}
/// A CORS response with no origin configured adds no headers to an empty base response.
#[test]
fn response_does_not_build_if_origin_is_not_set() {
    let cors_response = Response::new();
    let built = cors_response.response(response::Response::new());
    assert!(built.headers().iter().next().is_none());
}
/// Building onto an existing response drops any CORS headers it already carried (here a stale
/// `Access-Control-Max-Age`) while keeping its status and its unrelated headers.
#[test]
fn response_build_removes_existing_cors_headers_and_keeps_others() {
    use std::io::Cursor;
    let body = "Brewing the best coffee!";
    let original = response::Response::build()
        .status(Status::ImATeapot)
        .raw_header("X-Teapot-Make", "Rocket")
        .raw_header("Access-Control-Max-Age", "42")
        .sized_body(body.len(), Cursor::new(body))
        .finalize();
    let response = Response::new();
    let response = response.origin("https://www.example.com", false);
    let response = response.response(original);
    // "Keeps others" covers the base response's status as well as its non-CORS headers.
    assert_eq!(Status::ImATeapot, response.status());
    let expected_header = vec!["https://www.example.com"];
    let actual_header: Vec<_> = response
        .headers()
        .get("Access-Control-Allow-Origin")
        .collect();
    assert_eq!(expected_header, actual_header);
    let expected_header = vec!["Rocket"];
    let actual_header: Vec<_> = response.headers().get("X-Teapot-Make").collect();
    assert_eq!(expected_header, actual_header);
    assert!(response
        .headers()
        .get("Access-Control-Max-Age")
        .next()
        .is_none());
}
/// Single-field wrapper used to check how `crate::Method` (de)serializes as a struct field.
#[derive(PartialEq, Debug)]
#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))]
struct MethodTest {
    method: crate::Method,
}
/// `crate::Method` round-trips through serde as its method-name string (`"GET"` here).
#[cfg(feature = "serialization")]
#[test]
fn method_serde_roundtrip() {
    use serde_test::{assert_tokens, Token};

    let expected_tokens = [
        Token::Struct {
            name: "MethodTest",
            len: 1,
        },
        Token::Str("method"),
        Token::Str("GET"),
        Token::StructEnd,
    ];
    let value = MethodTest {
        method: From::from(http::Method::Get),
    };
    assert_tokens(&value, &expected_tokens);
}
/// A preflight from an allowed origin, asking for an allowed method and header, validates as
/// `Preflight` and carries the request's origin and requested headers.
#[test]
fn preflight_validated_correctly() {
    let cors = make_cors_options().to_cors().expect("To not fail");
    let client = make_client();
    let request = client
        .options("/")
        .header(Header::new(ORIGIN.as_str(), "https://www.acme.com"))
        .header(Header::new(
            ACCESS_CONTROL_REQUEST_METHOD.as_str(),
            hyper::Method::GET.as_str(),
        ))
        .header(Header::new(
            ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
            "Authorization",
        ));
    let outcome = validate(&cors, request.inner()).expect("to not fail");
    assert_eq!(
        ValidationResult::Preflight {
            origin: String::from("https://www.acme.com"),
            headers: Some(FromStr::from_str("Authorization").unwrap()),
        },
        outcome
    );
}
/// With every origin allowed, a preflight from an origin the default options reject still
/// validates, and that origin is reported back.
#[test]
fn preflight_validation_allows_all_origin() {
    let options = CorsOptions {
        allowed_origins: AllOrSome::All,
        ..make_cors_options()
    };
    let cors = options.to_cors().expect("To not fail");
    let client = make_client();
    let request = client
        .options("/")
        .header(Header::new(ORIGIN.as_str(), "https://www.example.com"))
        .header(Header::new(
            ACCESS_CONTROL_REQUEST_METHOD.as_str(),
            hyper::Method::GET.as_str(),
        ))
        .header(Header::new(
            ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
            "Authorization",
        ));
    let outcome = validate(&cors, request.inner()).expect("to not fail");
    let expected = ValidationResult::Preflight {
        origin: String::from("https://www.example.com"),
        headers: Some(FromStr::from_str("Authorization").unwrap()),
    };
    assert_eq!(expected, outcome);
}
/// A preflight from an origin the default options do not allow yields `OriginNotAllowed`.
#[test]
#[should_panic(expected = "OriginNotAllowed")]
fn preflight_validation_errors_on_invalid_origin() {
    let cors = make_cors_options().to_cors().expect("To not fail");
    let client = make_client();
    let request = client
        .options("/")
        .header(Header::new(ORIGIN.as_str(), "https://www.example.com"))
        .header(Header::new(
            ACCESS_CONTROL_REQUEST_METHOD.as_str(),
            hyper::Method::GET.as_str(),
        ))
        .header(Header::new(
            ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
            "Authorization",
        ));
    let _ = validate(&cors, request.inner()).unwrap();
}
/// An `OPTIONS` request with an origin and requested headers but no
/// `Access-Control-Request-Method` yields `MissingRequestMethod`.
#[test]
#[should_panic(expected = "MissingRequestMethod")]
fn preflight_validation_errors_on_missing_request_method() {
    let cors = make_cors_options().to_cors().expect("To not fail");
    let client = make_client();
    let request = client
        .options("/")
        .header(Header::new(ORIGIN.as_str(), "https://www.acme.com"))
        .header(Header::new(
            ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
            "Authorization",
        ));
    let _ = validate(&cors, request.inner()).unwrap();
}
/// A preflight requesting `POST`, which the default options reject, yields `MethodNotAllowed`.
#[test]
#[should_panic(expected = "MethodNotAllowed")]
fn preflight_validation_errors_on_disallowed_method() {
    let cors = make_cors_options().to_cors().expect("To not fail");
    let client = make_client();
    let request = client
        .options("/")
        .header(Header::new(ORIGIN.as_str(), "https://www.acme.com"))
        .header(Header::new(
            ACCESS_CONTROL_REQUEST_METHOD.as_str(),
            hyper::Method::POST.as_str(),
        ))
        .header(Header::new(
            ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
            "Authorization",
        ));
    let _ = validate(&cors, request.inner()).unwrap();
}
/// A preflight that requests a header outside the allowed set (`X-NOT-ALLOWED`) yields
/// `HeadersNotAllowed`.
#[test]
#[should_panic(expected = "HeadersNotAllowed")]
fn preflight_validation_errors_on_disallowed_headers() {
    let cors = make_cors_options().to_cors().expect("To not fail");
    let client = make_client();
    let request = client
        .options("/")
        .header(Header::new(ORIGIN.as_str(), "https://www.acme.com"))
        .header(Header::new(
            ACCESS_CONTROL_REQUEST_METHOD.as_str(),
            hyper::Method::GET.as_str(),
        ))
        .header(Header::new(
            ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
            "Authorization, X-NOT-ALLOWED",
        ));
    let _ = validate(&cors, request.inner()).unwrap();
}
/// A non-preflight request from an allowed origin validates as `Request` with that origin.
#[test]
fn actual_request_validated_correctly() {
    let cors = make_cors_options().to_cors().expect("To not fail");
    let client = make_client();
    let request = client
        .get("/")
        .header(Header::new(ORIGIN.as_str(), "https://www.acme.com"));
    let outcome = validate(&cors, request.inner()).expect("to not fail");
    assert_eq!(
        ValidationResult::Request {
            origin: String::from("https://www.acme.com"),
        },
        outcome
    );
}
/// With every origin allowed, a `GET` from an origin the default options reject still
/// validates as `Request`.
#[test]
fn actual_request_validation_allows_all_origin() {
    let options = CorsOptions {
        allowed_origins: AllOrSome::All,
        ..make_cors_options()
    };
    let cors = options.to_cors().expect("To not fail");
    let client = make_client();
    let request = client
        .get("/")
        .header(Header::new(ORIGIN.as_str(), "https://www.example.com"));
    let outcome = validate(&cors, request.inner()).expect("to not fail");
    assert_eq!(
        ValidationResult::Request {
            origin: String::from("https://www.example.com"),
        },
        outcome
    );
}
/// A `GET` from an origin the default options do not allow yields `OriginNotAllowed`.
#[test]
#[should_panic(expected = "OriginNotAllowed")]
fn actual_request_validation_errors_on_incorrect_origin() {
    let cors = make_cors_options().to_cors().expect("To not fail");
    let client = make_client();
    let request = client
        .get("/")
        .header(Header::new(ORIGIN.as_str(), "https://www.example.com"));
    let _ = validate(&cors, request.inner()).unwrap();
}
/// A request without an `Origin` header is not a CORS request, so validation builds an empty
/// CORS response.
#[test]
fn non_cors_request_return_empty_response() {
    let cors = make_cors_options().to_cors().expect("To not fail");
    let client = make_client();
    let request = client.options("/");
    let built = validate_and_build(&cors, request.inner()).expect("to not fail");
    assert_eq!(Response::new(), built);
}
/// A valid preflight builds a response that echoes the origin (with `vary_origin` false) and
/// the requested headers, plus the configured methods, credentials flag, and max age.
#[test]
fn preflight_validated_and_built_correctly() {
    let options = make_cors_options();
    let cors = options.to_cors().expect("To not fail");
    let client = make_client();
    let request = client
        .options("/")
        .header(Header::new(ORIGIN.as_str(), "https://www.acme.com"))
        .header(Header::new(
            ACCESS_CONTROL_REQUEST_METHOD.as_str(),
            hyper::Method::GET.as_str(),
        ))
        .header(Header::new(
            ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
            "Authorization",
        ));
    let built = validate_and_build(&cors, request.inner()).expect("to not fail");
    let expected = Response::new()
        .origin("https://www.acme.com", false)
        .headers(&["Authorization"])
        .methods(&options.allowed_methods)
        .credentials(options.allow_credentials)
        .max_age(options.max_age);
    assert_eq!(expected, built);
}
/// With all origins allowed and `send_wildcard` off, a preflight echoes the request origin with
/// `vary_origin` set, rather than answering with `*`.
#[test]
fn preflight_all_origins_with_vary() {
    let options = CorsOptions {
        allowed_origins: AllOrSome::All,
        send_wildcard: false,
        ..make_cors_options()
    };
    let cors = options.to_cors().expect("To not fail");
    let client = make_client();
    let request = client
        .options("/")
        .header(Header::new(ORIGIN.as_str(), "https://www.acme.com"))
        .header(Header::new(
            ACCESS_CONTROL_REQUEST_METHOD.as_str(),
            hyper::Method::GET.as_str(),
        ))
        .header(Header::new(
            ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
            "Authorization",
        ));
    let built = validate_and_build(&cors, request.inner()).expect("to not fail");
    let expected = Response::new()
        .origin("https://www.acme.com", true)
        .headers(&["Authorization"])
        .methods(&options.allowed_methods)
        .credentials(options.allow_credentials)
        .max_age(options.max_age);
    assert_eq!(expected, built);
}
/// With all origins allowed, `send_wildcard` on, and credentials off, a preflight answers with
/// the `*` wildcard origin.
#[test]
fn preflight_all_origins_with_wildcard() {
    let options = CorsOptions {
        allowed_origins: AllOrSome::All,
        send_wildcard: true,
        allow_credentials: false,
        ..make_cors_options()
    };
    let cors = options.to_cors().expect("To not fail");
    let client = make_client();
    let request = client
        .options("/")
        .header(Header::new(ORIGIN.as_str(), "https://www.acme.com"))
        .header(Header::new(
            ACCESS_CONTROL_REQUEST_METHOD.as_str(),
            hyper::Method::GET.as_str(),
        ))
        .header(Header::new(
            ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
            "Authorization",
        ));
    let built = validate_and_build(&cors, request.inner()).expect("to not fail");
    let expected = Response::new()
        .any()
        .headers(&["Authorization"])
        .methods(&options.allowed_methods)
        .credentials(options.allow_credentials)
        .max_age(options.max_age);
    assert_eq!(expected, built);
}
/// A valid actual request builds a response that echoes the origin and carries the configured
/// credentials flag and exposed headers.
#[test]
fn actual_request_validated_and_built_correctly() {
    let options = make_cors_options();
    let cors = options.to_cors().expect("To not fail");
    let client = make_client();
    let request = client
        .get("/")
        .header(Header::new(ORIGIN.as_str(), "https://www.acme.com"));
    let built = validate_and_build(&cors, request.inner()).expect("to not fail");
    let expected = Response::new()
        .origin("https://www.acme.com", false)
        .credentials(options.allow_credentials)
        .exposed_headers(&["Content-Type", "X-Custom"]);
    assert_eq!(expected, built);
}
/// With all origins allowed, `send_wildcard` off, and credentials off, an actual request echoes
/// its origin with `vary_origin` set.
#[test]
fn actual_request_all_origins_with_vary() {
    let options = CorsOptions {
        allowed_origins: AllOrSome::All,
        send_wildcard: false,
        allow_credentials: false,
        ..make_cors_options()
    };
    let cors = options.to_cors().expect("To not fail");
    let client = make_client();
    let request = client
        .get("/")
        .header(Header::new(ORIGIN.as_str(), "https://www.acme.com"));
    let built = validate_and_build(&cors, request.inner()).expect("to not fail");
    let expected = Response::new()
        .origin("https://www.acme.com", true)
        .credentials(options.allow_credentials)
        .exposed_headers(&["Content-Type", "X-Custom"]);
    assert_eq!(expected, built);
}
/// With all origins allowed, `send_wildcard` on, and credentials off, an actual request is
/// answered with the `*` wildcard origin.
#[test]
fn actual_request_all_origins_with_wildcard() {
    let options = CorsOptions {
        allowed_origins: AllOrSome::All,
        send_wildcard: true,
        allow_credentials: false,
        ..make_cors_options()
    };
    let cors = options.to_cors().expect("To not fail");
    let client = make_client();
    let request = client
        .get("/")
        .header(Header::new(ORIGIN.as_str(), "https://www.acme.com"));
    let built = validate_and_build(&cors, request.inner()).expect("to not fail");
    let expected = Response::new()
        .any()
        .credentials(options.allow_credentials)
        .exposed_headers(&["Content-Type", "X-Custom"]);
    assert_eq!(expected, built);
}
}