gotham_cors_middleware/
lib.rs

1//! Library aimed at providing CORS functionality
2//! for Gotham based servers.
3//!
4//! Currently a very basic implementation with
5//! limited customisability.
6#[macro_use]
7extern crate gotham_derive;
8
9extern crate futures;
10extern crate gotham;
11extern crate hyper;
12extern crate unicase;
13
14use futures::Future;
15use gotham::handler::HandlerFuture;
16use gotham::middleware::Middleware;
17use gotham::state::{FromState, State};
18use hyper::header::{
19    AccessControlAllowCredentials, AccessControlAllowHeaders, AccessControlAllowMethods,
20    AccessControlAllowOrigin, AccessControlMaxAge, Headers, Origin,
21};
22use hyper::Method;
23use std::option::Option;
24use unicase::Ascii;
25
26/// Struct to perform the necessary CORS
27/// functionality needed. Allows some
28/// customisation through use of the
29/// new() function.
30///
31/// Example of use:
32/// ```rust
33/// extern crate gotham;
34/// extern crate gotham_cors_middleware;
35///
36/// use gotham::pipeline::new_pipeline;
37/// use gotham_cors_middleware::CORSMiddleware;
38/// use gotham::pipeline::single::single_pipeline;
39/// use gotham::router::builder::*;
40/// use gotham::router::Router;
41///
42/// pub fn router() -> Router {
43///     let (chain, pipeline) = single_pipeline(
44///         new_pipeline()
45///             .add(CORSMiddleware::default())
46///             .build()
47///     );
48///
49///     build_router(chain, pipeline, |route| {
50///         // Routes
51///     })
52/// }
53/// ```
54#[derive(Clone, NewMiddleware, Debug, PartialEq)]
55pub struct CORSMiddleware {
56    methods: Vec<Method>,
57    origin: Option<String>,
58    max_age: u32,
59}
60
61impl CORSMiddleware {
62    /// Create a new CORSMiddleware with custom methods,
63    /// origin and max_age properties.
64    ///
65    /// Expects methods to be a Vec of hyper::Method enum
66    /// values, origin to be an Option containing a String
67    /// (so allows for None values - which defaults to
68    /// returning the sender origin on request or returning
69    /// a string of "*" - see the call function source) and
70    /// max age to be a u32 value.
71    ///
72    /// Example of use:
73    /// ```rust
74    /// extern crate gotham;
75    /// extern crate gotham_cors_middleware;
76    /// extern crate hyper;
77    ///
78    /// use gotham::pipeline::new_pipeline;
79    /// use gotham_cors_middleware::CORSMiddleware;
80    /// use gotham::pipeline::single::single_pipeline;
81    /// use gotham::router::builder::*;
82    /// use gotham::router::Router;
83    /// use hyper::Method;
84    ///
85    /// fn create_custom_middleware() -> CORSMiddleware {
86    ///     let methods = vec![Method::Delete, Method::Get, Method::Head, Method::Options];
87    ///
88    ///     let max_age = 1000;
89    ///
90    ///     let origin = Some("http://www.example.com".to_string());
91    ///
92    ///     CORSMiddleware::new(methods, origin, max_age)
93    /// }
94    ///
95    /// pub fn router() -> Router {
96    ///     let (chain, pipeline) = single_pipeline(
97    ///         new_pipeline()
98    ///             .add(create_custom_middleware())
99    ///             .build()
100    ///     );
101    ///
102    ///     build_router(chain, pipeline, |route| {
103    ///         // Routes
104    ///     })
105    /// }
106    /// ```
107    pub fn new(methods: Vec<Method>, origin: Option<String>, max_age: u32) -> CORSMiddleware {
108        CORSMiddleware {
109            methods,
110            origin,
111            max_age,
112        }
113    }
114
115    /// Creates a new CORSMiddleware with what is currently
116    /// the "default" values for methods/origin/max_age.
117    ///
118    /// This is based off the values that were used previously
119    /// before they were customisable. If you need different
120    /// values, use the new() function.
121    pub fn default() -> CORSMiddleware {
122        let methods = vec![
123            Method::Delete,
124            Method::Get,
125            Method::Head,
126            Method::Options,
127            Method::Patch,
128            Method::Post,
129            Method::Put,
130        ];
131
132        let origin = None;
133        let max_age = 86400;
134
135        CORSMiddleware::new(methods, origin, max_age)
136    }
137}
138
139impl Middleware for CORSMiddleware {
140    fn call<Chain>(self, state: State, chain: Chain) -> Box<HandlerFuture>
141    where
142        Chain: FnOnce(State) -> Box<HandlerFuture>,
143    {
144        let settings = self.clone();
145        let f = chain(state).map(|(state, response)| {
146            let origin: String;
147            if settings.origin.is_none() {
148                let origin_raw = Headers::borrow_from(&state).get::<Origin>().clone();
149                let ori = match origin_raw {
150                    Some(o) => o.to_string(),
151                    None => "*".to_string(),
152                };
153
154                origin = ori;
155            } else {
156                origin = settings.origin.unwrap();
157            };
158
159            let mut headers = Headers::new();
160
161            headers.set(AccessControlAllowCredentials);
162            headers.set(AccessControlAllowHeaders(vec![
163                Ascii::new("Authorization".to_string()),
164                Ascii::new("Content-Type".to_string()),
165            ]));
166            headers.set(AccessControlAllowOrigin::Value(origin));
167            headers.set(AccessControlAllowMethods(settings.methods));
168            headers.set(AccessControlMaxAge(settings.max_age));
169
170            let res = response.with_headers(headers);
171
172            (state, res)
173        });
174
175        Box::new(f)
176    }
177}
178
#[cfg(test)]
mod tests {
    extern crate mime;

    use super::*;

    use futures::future;
    use gotham::http::response::create_response;
    use gotham::pipeline::new_pipeline;
    use gotham::pipeline::single::single_pipeline;
    use gotham::router::builder::*;
    use gotham::router::Router;
    use gotham::test::TestServer;
    use hyper::Method::Options;
    use hyper::StatusCode;
    use hyper::{Get, Head};

    // Since we cannot construct 'State' ourselves, we need to test via an 'actual' app
    fn handler(state: State) -> Box<HandlerFuture> {
        let body = "Hello World".to_string();

        let response = create_response(
            &state,
            StatusCode::Ok,
            Some((body.into_bytes(), mime::TEXT_PLAIN)),
        );

        Box::new(future::ok((state, response)))
    }

    /// Router wrapping `handler` with the default middleware configuration.
    fn default_router() -> Router {
        let (chain, pipeline) =
            single_pipeline(new_pipeline().add(CORSMiddleware::default()).build());

        build_router(chain, pipeline, |route| {
            route.request(vec![Get, Head, Options], "/").to(handler);
        })
    }

    /// Methods used by `custom_router`, shared with the assertions below.
    fn custom_methods() -> Vec<Method> {
        vec![Method::Delete, Method::Get, Method::Head, Method::Options]
    }

    /// Router wrapping `handler` with a fixed origin and custom settings.
    fn custom_router() -> Router {
        let max_age = 1000;

        let origin = Some("http://www.example.com".to_string());

        let (chain, pipeline) = single_pipeline(
            new_pipeline()
                .add(CORSMiddleware::new(custom_methods(), origin, max_age))
                .build(),
        );

        build_router(chain, pipeline, |route| {
            route.request(vec![Get, Head, Options], "/").to(handler);
        })
    }

    #[test]
    fn test_headers_set() {
        let test_server = TestServer::new(default_router()).unwrap();

        let response = test_server
            .client()
            .get("https://example.com/")
            .perform()
            .unwrap();

        assert_eq!(response.status(), StatusCode::Ok);
        let headers = response.headers();
        assert_eq!(
            headers
                .get::<AccessControlAllowOrigin>()
                .unwrap()
                .to_string(),
            "*".to_string()
        );
        assert_eq!(
            headers.get::<AccessControlMaxAge>().unwrap().to_string(),
            "86400".to_string()
        );
        assert!(headers.has::<AccessControlAllowCredentials>());
        assert_eq!(
            headers.get::<AccessControlAllowMethods>().unwrap().0,
            CORSMiddleware::default().methods
        );
    }

    #[test]
    fn test_custom_headers_set() {
        let test_server = TestServer::new(custom_router()).unwrap();

        let response = test_server
            .client()
            .get("https://example.com/")
            .perform()
            .unwrap();

        assert_eq!(response.status(), StatusCode::Ok);
        let headers = response.headers();
        assert_eq!(
            headers
                .get::<AccessControlAllowOrigin>()
                .unwrap()
                .to_string(),
            "http://www.example.com".to_string()
        );
        assert_eq!(
            headers.get::<AccessControlMaxAge>().unwrap().to_string(),
            "1000".to_string()
        );
        assert!(headers.has::<AccessControlAllowCredentials>());
        assert_eq!(
            headers.get::<AccessControlAllowMethods>().unwrap().0,
            custom_methods()
        );
    }

    #[test]
    fn test_new_cors_middleware() {
        let methods = custom_methods();

        let max_age = 1000;

        let origin = Some("http://www.example.com".to_string());

        // `max_age` is a `u32` (Copy), so it is passed by value, not cloned.
        let test = CORSMiddleware::new(methods.clone(), origin.clone(), max_age);

        let default = CORSMiddleware::default();

        assert_ne!(test, default);

        assert_eq!(test.origin, origin);
        assert_eq!(test.max_age, max_age);
        assert_eq!(test.methods, methods);
    }

    #[test]
    fn test_default_cors_middleware() {
        let test = CORSMiddleware::default();
        let methods = vec![
            Method::Delete,
            Method::Get,
            Method::Head,
            Method::Options,
            Method::Patch,
            Method::Post,
            Method::Put,
        ];

        assert_eq!(test.methods, methods);

        assert_eq!(test.max_age, 86400);

        assert_eq!(test.origin, None);
    }
}