1#[macro_use]
7extern crate gotham_derive;
8
9extern crate futures;
10extern crate gotham;
11extern crate hyper;
12extern crate unicase;
13
14use futures::Future;
15use gotham::handler::HandlerFuture;
16use gotham::middleware::Middleware;
17use gotham::state::{FromState, State};
18use hyper::header::{
19 AccessControlAllowCredentials, AccessControlAllowHeaders, AccessControlAllowMethods,
20 AccessControlAllowOrigin, AccessControlMaxAge, Headers, Origin,
21};
22use hyper::Method;
23use std::option::Option;
24use unicase::Ascii;
25
26#[derive(Clone, NewMiddleware, Debug, PartialEq)]
55pub struct CORSMiddleware {
56 methods: Vec<Method>,
57 origin: Option<String>,
58 max_age: u32,
59}
60
61impl CORSMiddleware {
62 pub fn new(methods: Vec<Method>, origin: Option<String>, max_age: u32) -> CORSMiddleware {
108 CORSMiddleware {
109 methods,
110 origin,
111 max_age,
112 }
113 }
114
115 pub fn default() -> CORSMiddleware {
122 let methods = vec![
123 Method::Delete,
124 Method::Get,
125 Method::Head,
126 Method::Options,
127 Method::Patch,
128 Method::Post,
129 Method::Put,
130 ];
131
132 let origin = None;
133 let max_age = 86400;
134
135 CORSMiddleware::new(methods, origin, max_age)
136 }
137}
138
139impl Middleware for CORSMiddleware {
140 fn call<Chain>(self, state: State, chain: Chain) -> Box<HandlerFuture>
141 where
142 Chain: FnOnce(State) -> Box<HandlerFuture>,
143 {
144 let settings = self.clone();
145 let f = chain(state).map(|(state, response)| {
146 let origin: String;
147 if settings.origin.is_none() {
148 let origin_raw = Headers::borrow_from(&state).get::<Origin>().clone();
149 let ori = match origin_raw {
150 Some(o) => o.to_string(),
151 None => "*".to_string(),
152 };
153
154 origin = ori;
155 } else {
156 origin = settings.origin.unwrap();
157 };
158
159 let mut headers = Headers::new();
160
161 headers.set(AccessControlAllowCredentials);
162 headers.set(AccessControlAllowHeaders(vec![
163 Ascii::new("Authorization".to_string()),
164 Ascii::new("Content-Type".to_string()),
165 ]));
166 headers.set(AccessControlAllowOrigin::Value(origin));
167 headers.set(AccessControlAllowMethods(settings.methods));
168 headers.set(AccessControlMaxAge(settings.max_age));
169
170 let res = response.with_headers(headers);
171
172 (state, res)
173 });
174
175 Box::new(f)
176 }
177}
178
#[cfg(test)]
mod tests {
    extern crate mime;

    use super::*;

    use futures::future;
    use gotham::http::response::create_response;
    use gotham::pipeline::new_pipeline;
    use gotham::pipeline::single::single_pipeline;
    use gotham::router::builder::*;
    use gotham::router::Router;
    use gotham::test::TestServer;
    use hyper::Method::Options;
    use hyper::StatusCode;
    use hyper::{Get, Head};

    /// Origin used by the custom-configuration tests.
    const CUSTOM_ORIGIN: &'static str = "http://www.example.com";
    /// Max age used by the custom-configuration tests.
    const CUSTOM_MAX_AGE: u32 = 1000;

    /// Methods used by the custom-configuration tests.
    fn custom_methods() -> Vec<Method> {
        vec![Method::Delete, Method::Get, Method::Head, Method::Options]
    }

    /// Trivial handler answering `200 OK` with a plain-text body.
    fn handler(state: State) -> Box<HandlerFuture> {
        let response = create_response(
            &state,
            StatusCode::Ok,
            Some(("Hello World".to_string().into_bytes(), mime::TEXT_PLAIN)),
        );

        Box::new(future::ok((state, response)))
    }

    /// Builds a router serving `handler` at `/` behind the given middleware.
    fn router_with(middleware: CORSMiddleware) -> Router {
        let (chain, pipeline) = single_pipeline(new_pipeline().add(middleware).build());

        build_router(chain, pipeline, |route| {
            route.request(vec![Get, Head, Options], "/").to(handler);
        })
    }

    fn default_router() -> Router {
        router_with(CORSMiddleware::default())
    }

    fn custom_router() -> Router {
        router_with(CORSMiddleware::new(
            custom_methods(),
            Some(CUSTOM_ORIGIN.to_string()),
            CUSTOM_MAX_AGE,
        ))
    }

    /// Performs `GET https://example.com/` against `router` and returns the
    /// status plus the rendered `Access-Control-Allow-Origin` and
    /// `Access-Control-Max-Age` values (panicking if either header is missing).
    fn fetch_cors_headers(router: Router) -> (StatusCode, String, String) {
        let test_server = TestServer::new(router).unwrap();
        let response = test_server
            .client()
            .get("https://example.com/")
            .perform()
            .unwrap();

        let headers = response.headers();
        let allow_origin = headers
            .get::<AccessControlAllowOrigin>()
            .unwrap()
            .to_string();
        let max_age = headers.get::<AccessControlMaxAge>().unwrap().to_string();

        (response.status(), allow_origin, max_age)
    }

    #[test]
    fn test_headers_set() {
        let (status, allow_origin, max_age) = fetch_cors_headers(default_router());

        assert_eq!(status, StatusCode::Ok);
        assert_eq!(allow_origin, "*".to_string());
        assert_eq!(max_age, "86400".to_string());
    }

    #[test]
    fn test_custom_headers_set() {
        let (status, allow_origin, max_age) = fetch_cors_headers(custom_router());

        assert_eq!(status, StatusCode::Ok);
        assert_eq!(allow_origin, CUSTOM_ORIGIN.to_string());
        assert_eq!(max_age, "1000".to_string());
    }

    #[test]
    fn test_new_cors_middleware() {
        let methods = custom_methods();
        let origin = Some(CUSTOM_ORIGIN.to_string());

        // `u32` is `Copy`, so `CUSTOM_MAX_AGE` needs no `.clone()`.
        let test = CORSMiddleware::new(methods.clone(), origin.clone(), CUSTOM_MAX_AGE);

        assert_ne!(test, CORSMiddleware::default());

        assert_eq!(test.origin, origin);
        assert_eq!(test.max_age, CUSTOM_MAX_AGE);
        assert_eq!(test.methods, methods);
    }

    #[test]
    fn test_default_cors_middleware() {
        let test = CORSMiddleware::default();

        // Spelled out explicitly (not derived from the implementation) so the
        // test catches accidental changes to the default method list.
        let expected_methods = vec![
            Method::Delete,
            Method::Get,
            Method::Head,
            Method::Options,
            Method::Patch,
            Method::Post,
            Method::Put,
        ];

        assert_eq!(test.methods, expected_methods);
        assert_eq!(test.max_age, 86400);
        assert_eq!(test.origin, None);
    }
}