extern crate iron;
extern crate unicase;
#[macro_use]
extern crate hyper;
pub use unicase::UniCase;
use iron::prelude::*;
use iron::method::Method;
use iron::method::Method::*;
use iron::status;
use iron::headers::{AccessControlRequestMethod, AccessControlRequestHeaders,
AccessControlAllowOrigin, AccessControlAllowHeaders, AccessControlMaxAge,
AccessControlAllowMethods, AccessControlAllowCredentials,
AccessControlExposeHeaders, Vary};
use iron::middleware::{AroundMiddleware, Handler};
use std::collections::HashSet;
use std::iter::FromIterator;
pub use origin::Origin;
mod origin;
// Typed accessor for the request's `Origin` header, carrying its raw value as a
// `String`. The value is only turned into an `Origin` later, when it is handed
// to `AllowedOrigins::allowed_for`.
header! {
(OriginHeader, "Origin") => [String]
}
/// Which request origins a `CorsMiddleware` accepts.
#[derive(Clone)]
pub enum AllowedOrigins {
/// Accept any origin that parses successfully.
Any {
/// Whether the literal `null` origin is accepted as well.
allow_null: bool,
},
/// Accept only the (parsed) origins contained in this set.
Specific(HashSet<Origin>),
}
impl AllowedOrigins {
    /// Builds the `Access-Control-Allow-Origin` value for an origin that has
    /// already passed the allow-list check in `allowed_for`.
    ///
    /// With credentials enabled the concrete origin is always echoed back,
    /// because a `*` value is not accepted on credentialed CORS requests.
    /// Without credentials, `*` is used when `prefer_wildcard` is set.
    /// Always returns `Some`.
    fn allow(&self,
             origin_string: &str,
             prefer_wildcard: bool,
             allow_credentials: bool)
             -> Option<String> {
        let use_wildcard = prefer_wildcard && !allow_credentials;
        let value = if use_wildcard { "*" } else { origin_string };
        Some(value.to_owned())
    }

    /// Decides whether `origin_string` (the raw `Origin` request header) is
    /// permitted and, if so, returns the value to send back in
    /// `Access-Control-Allow-Origin`.
    ///
    /// Returns `None` when the value cannot be parsed, when it is `null` and
    /// `null` is not allowed, or when it is absent from a `Specific` set.
    pub fn allowed_for(&self,
                       origin_string: &str,
                       allow_credentials: bool,
                       prefer_wildcard: bool)
                       -> Option<String> {
        let origin = match Origin::parse_allow_null(origin_string) {
            Ok(parsed) => parsed,
            Err(_) => return None,
        };

        let permitted = match *self {
            AllowedOrigins::Any { allow_null } => origin != Origin::Null || allow_null,
            AllowedOrigins::Specific(ref allowed) => allowed.contains(&origin),
        };

        if permitted {
            self.allow(origin_string, prefer_wildcard, allow_credentials)
        } else {
            None
        }
    }
}
/// Iron middleware that answers CORS preflight requests itself and adds CORS
/// headers to the wrapped handler's responses. `CorsMiddleware::permissive()`
/// returns a ready-made configuration.
#[derive(Clone)]
pub struct CorsMiddleware {
/// Origins accepted; checked against the request `Origin` header on both
/// preflight and normal requests.
pub allowed_origins: AllowedOrigins,
/// Methods a preflight may request; also sent as `Access-Control-Allow-Methods`.
pub allowed_methods: Vec<Method>,
/// Request headers a preflight may ask for (compared as `UniCase`); also sent
/// as `Access-Control-Allow-Headers`.
pub allowed_headers: Vec<UniCase<String>>,
/// Sent as `Access-Control-Expose-Headers` on normal (non-preflight) responses
/// when non-empty.
pub exposed_headers: Vec<UniCase<String>>,
/// When true, `Access-Control-Allow-Credentials` is sent and the request
/// origin is always echoed back instead of `*`.
pub allow_credentials: bool,
/// Sent as `Access-Control-Max-Age` on preflight responses.
pub max_age_seconds: u32,
/// When true and credentials are not allowed, respond with
/// `Access-Control-Allow-Origin: *` instead of echoing the origin.
pub prefer_wildcard: bool,
}
pub fn all_std_methods() -> Vec<Method> {
vec![Options, Get, Post, Put, Delete, Head, Trace, Connect, Patch]
}
/// A small default set of request headers to allow: `Authorization`,
/// `Content-Type` and `X-Requested-With`. Used by `CorsMiddleware::permissive`.
pub fn common_req_headers() -> Vec<unicase::UniCase<String>> {
    ["Authorization", "Content-Type", "X-Requested-With"]
        .iter()
        .map(|name| UniCase((*name).to_owned()))
        .collect()
}
impl CorsMiddleware {
pub fn permissive() -> CorsMiddleware {
CorsMiddleware {
allowed_origins: AllowedOrigins::Any { allow_null: false },
allowed_methods: all_std_methods(),
allowed_headers: common_req_headers(),
exposed_headers: vec![],
allow_credentials: false,
max_age_seconds: 60 * 60,
prefer_wildcard: false,
}
}
fn vary_headers() -> Vec<UniCase<String>> {
vec![UniCase("Origin".to_owned()),
UniCase("Access-Control-Request-Method".to_owned()),
UniCase("Access-Control-Request-Headers".to_owned())]
}
fn handle(&self, req: &mut Request, handler: &Handler) -> IronResult<Response> {
let res = if req.method == Options &&
req.headers.get::<AccessControlRequestMethod>().is_some() {
self.handle_preflight(req, handler)
} else {
self.handle_normal(req, handler)
};
match res {
Ok(mut r) => {
r.headers.set(Vary::Items(CorsMiddleware::vary_headers()));
Ok(r)
}
x => x,
}
}
fn handle_preflight(&self, req: &mut Request, _: &Handler) -> IronResult<Response> {
let mut res = Response::with(status::NoContent);
let maybe_origin = req.headers.get::<OriginHeader>();
if maybe_origin.is_none() {
let resp = Response::with((status::BadRequest,
"Preflight request without Origin header"));
return Ok(resp);
}
let origin = maybe_origin.unwrap();
let origin_str = origin.to_string();
let allowed_origin =
self.allowed_origins.allowed_for(&origin_str,
self.allow_credentials,
self.prefer_wildcard);
if allowed_origin.is_none() {
let resp = Response::with((status::BadRequest,
format!("Preflight request requesting \
disallowed origin '{}'",
origin_str)));
return Ok(resp);
}
let requested_method = req.headers.get::<AccessControlRequestMethod>().unwrap();
let empty_vec: Vec<UniCase<String>> = vec![];
let maybe_requested_headers = req.headers.get::<AccessControlRequestHeaders>();
let requested_headers: &Vec<UniCase<String>> = if maybe_requested_headers.is_some() {
&maybe_requested_headers.unwrap().0
} else {
&empty_vec
};
if !self.allowed_methods.contains(requested_method) {
return Ok(Response::with((status::BadRequest,
format!("Preflight request requesting disallowed method {}",
requested_method))));
}
let requested_headers_set: HashSet<UniCase<String>> =
HashSet::from_iter(requested_headers.iter().cloned());
let allowed_headers_set: HashSet<UniCase<String>> =
HashSet::from_iter(self.allowed_headers.iter().cloned());
let disallowed_headers: HashSet<UniCase<String>> =
requested_headers_set.difference(&allowed_headers_set).cloned().collect();
if !disallowed_headers.is_empty() {
let a = disallowed_headers.iter()
.map(|uh| uh.to_string())
.collect::<Vec<_>>()
.join(",");
let msg = format!("Preflight request requesting disallowed header(s) {}", a);
return Ok(Response::with((status::BadRequest, msg)));
}
if self.allow_credentials {
res.headers.set(AccessControlAllowCredentials);
}
res.headers.set(AccessControlAllowOrigin::Value(allowed_origin.unwrap()));
res.headers.set(AccessControlMaxAge(self.max_age_seconds));
res.headers.set(AccessControlAllowMethods(self.allowed_methods.clone()));
res.headers.set(AccessControlAllowHeaders(self.allowed_headers.clone()));
Ok(res)
}
fn handle_normal(&self, req: &mut Request, handler: &Handler) -> IronResult<Response> {
let have_origin;
{
let maybe_origin = req.headers.get::<OriginHeader>();
have_origin = maybe_origin.is_some();
}
if !have_origin {
return handler.handle(req);
}
let origin = req.headers
.get::<OriginHeader>()
.unwrap()
.clone();
let origin_str = origin.to_string();
let allowed_origin =
self.allowed_origins.allowed_for(&origin_str,
self.allow_credentials,
self.prefer_wildcard);
if allowed_origin.is_none() {
let resp = Response::with((status::BadRequest,
format!("Normal request requesting \
disallowed origin '{}'",
origin_str)));
return Ok(resp);
}
let result = handler.handle(req);
match result {
Ok(mut res) => {
if self.allow_credentials {
res.headers.set(AccessControlAllowCredentials);
}
res.headers.set(AccessControlAllowOrigin::Value(allowed_origin.unwrap()));
if !self.exposed_headers.is_empty() {
res.headers.set(AccessControlExposeHeaders(self.exposed_headers.clone()));
}
Ok(res)
}
_ => result,
}
}
pub fn decorate<T: Handler>(self, handler: T) -> Chain {
let mut chain = Chain::new(handler);
chain.link_around(self);
chain
}
}
impl AroundMiddleware for CorsMiddleware {
    /// Produces a handler that sends every request through
    /// `CorsMiddleware::handle`, which decides whether and how to call the
    /// wrapped `handler`.
    fn around(self, handler: Box<Handler>) -> Box<Handler> {
        let cors = self;
        Box::new(move |request: &mut Request| cors.handle(request, &*handler))
    }
}