1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
//! A CORS middleware for Iron.
//!
//! See <https://www.html5rocks.com/static/images/cors_server_flowchart.png>
//! for reference.
//!
//! The middleware will return `HTTP 400 Bad Request` if the `Origin` header is
//! missing or its host is not allowed.
//!
//! Preflight requests are not yet supported.

extern crate iron;
#[macro_use] extern crate log;

use iron::{Request, Response, IronResult, AroundMiddleware, Handler};
use iron::{headers, status};

/// The struct that holds the CORS configuration.
///
/// Build it with `CorsMiddleware::with_whitelist` or
/// `CorsMiddleware::with_allow_any`, then install it as an Iron
/// `AroundMiddleware`.
#[derive(Clone, Debug)]
pub struct CorsMiddleware {
    /// Origin hostnames allowed to access the resource; `None` allows any
    /// origin.
    allowed_hosts: Option<Vec<String>>,
}

impl CorsMiddleware {
    /// Allows only the listed origin hosts to access the resource.
    ///
    /// Each entry is compared against the hostname of the request's `Origin`
    /// header. The origin's scheme and port are not taken into account, and
    /// the comparison is an exact, case-sensitive string match.
    pub fn with_whitelist(allowed_hosts: Vec<String>) -> Self {
        CorsMiddleware { allowed_hosts: Some(allowed_hosts) }
    }

    /// Allows requests from any origin to access the resource.
    ///
    /// Requests must still carry an `Origin` header; requests without one
    /// are rejected with `400 Bad Request`.
    pub fn with_allow_any() -> Self {
        CorsMiddleware { allowed_hosts: None }
    }
}

impl AroundMiddleware for CorsMiddleware {
    /// Wraps `handler` so that every request passes the CORS origin check
    /// before reaching it.
    ///
    /// The middleware is consumed here, so its whitelist is moved into the
    /// wrapping handler rather than cloned. Boxing the result is required by
    /// the `AroundMiddleware` interface.
    fn around(self, handler: Box<Handler>) -> Box<Handler> {
        Box::new(CorsHandler {
            handler: handler,
            allowed_hosts: self.allowed_hosts,
        })
    }
}

/// Handler produced by `CorsMiddleware::around`.
///
/// It pairs the next handler in the chain with the origin whitelist that
/// requests must satisfy before being forwarded to it.
struct CorsHandler {
    /// Allowed origin hostnames; `None` accepts every origin.
    allowed_hosts: Option<Vec<String>>,
    /// Handler that receives requests whose origin passed the check.
    handler: Box<Handler>,
}


/// The handler that acts as an AroundMiddleware.
///
/// It first checks an incoming request for an `Origin` header and verifies
/// the origin's host against the configured whitelist. Allowed requests are
/// passed to the wrapped handler, and the response gets an
/// `Access-Control-Allow-Origin` header echoing the request's origin.
/// Otherwise a `400 Bad Request` response is returned and the wrapped handler
/// is never called.
///
/// # Errors
///
/// Errors returned by the wrapped handler are propagated unchanged.
///
/// NOTE(review): because the allowed origin is echoed back, responses differ
/// per `Origin`; consider also sending `Vary: Origin` so shared caches do not
/// serve one origin's response to another.
impl Handler for CorsHandler {
    fn handle(&self, req: &mut Request) -> IronResult<Response> {
        // Decide on the request while `req.headers` is only borrowed. The
        // borrow ends with this statement, so the wrapped handler can take
        // `req` mutably below without the Origin header having to be cloned.
        let allow_origin = match req.headers.get::<headers::Origin>() {
            Some(origin) => {
                if !self.is_allowed(&origin.host.hostname) {
                    warn!("Got disallowed CORS request from {}", &origin.host.hostname);
                    return Ok(Response::with((status::BadRequest, "Invalid CORS request: Origin not allowed")));
                }
                allow_origin_value(origin)
            }
            None => {
                warn!("Not a valid CORS request: Missing Origin header");
                return Ok(Response::with((status::BadRequest, "Invalid CORS request: Origin header missing")));
            }
        };

        // Everything OK, process request.
        let mut res = try!(self.handler.handle(req));
        res.headers.set(headers::AccessControlAllowOrigin::Value(allow_origin));
        Ok(res)
    }
}

impl CorsHandler {
    /// Returns `true` when `hostname` may access the resource.
    ///
    /// Without a whitelist every hostname is allowed; otherwise `hostname`
    /// must exactly (case-sensitively) equal one of the whitelisted entries.
    fn is_allowed(&self, hostname: &str) -> bool {
        match self.allowed_hosts {
            Some(ref allowed_hosts) => allowed_hosts.iter().any(|host| host == hostname),
            None => true,
        }
    }
}

/// Builds the `Access-Control-Allow-Origin` value for `origin`, formatted as
/// `scheme://hostname` plus `:port` when the origin carries an explicit port.
fn allow_origin_value(origin: &headers::Origin) -> String {
    match origin.host.port {
        Some(port) => format!("{}://{}:{}", origin.scheme, origin.host.hostname, port),
        None => format!("{}://{}", origin.scheme, origin.host.hostname),
    }
}