1extern crate unicase;
3
4use std::{fmt, ops};
5use hosts::{Host, Port};
6use matcher::{Matcher, Pattern};
7use std::collections::HashSet;
8pub use self::unicase::Ascii;
9
/// Protocol (scheme) component of an origin.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OriginProtocol {
    /// The `http` scheme.
    Http,
    /// The `https` scheme.
    Https,
    /// Any other scheme, stored by name (for example `chrome-extension`).
    Custom(String),
}
20
21#[derive(Clone, PartialEq, Eq, Debug, Hash)]
23pub struct Origin {
24 protocol: OriginProtocol,
25 host: Host,
26 as_string: String,
27 matcher: Matcher,
28}
29
30impl<T: AsRef<str>> From<T> for Origin {
31 fn from(string: T) -> Self {
32 Origin::parse(string.as_ref())
33 }
34}
35
36impl Origin {
37 fn with_host(protocol: OriginProtocol, host: Host) -> Self {
38 let string = Self::to_string(&protocol, &host);
39 let matcher = Matcher::new(&string);
40
41 Origin {
42 protocol: protocol,
43 host: host,
44 as_string: string,
45 matcher: matcher,
46 }
47 }
48
49 pub fn new<T: Into<Port>>(protocol: OriginProtocol, host: &str, port: T) -> Self {
52 Self::with_host(protocol, Host::new(host, port))
53 }
54
55 pub fn parse(data: &str) -> Self {
58 let mut it = data.split("://");
59 let proto = it.next().expect("split always returns non-empty iterator.");
60 let hostname = it.next();
61
62 let (proto, hostname) = match hostname {
63 None => (None, proto),
64 Some(hostname) => (Some(proto), hostname),
65 };
66
67 let proto = proto.map(str::to_lowercase);
68 let hostname = Host::parse(hostname);
69
70 let protocol = match proto {
71 None => OriginProtocol::Http,
72 Some(ref p) if p == "http" => OriginProtocol::Http,
73 Some(ref p) if p == "https" => OriginProtocol::Https,
74 Some(other) => OriginProtocol::Custom(other),
75 };
76
77 Origin::with_host(protocol, hostname)
78 }
79
80 fn to_string(protocol: &OriginProtocol, host: &Host) -> String {
81 format!(
82 "{}://{}",
83 match *protocol {
84 OriginProtocol::Http => "http",
85 OriginProtocol::Https => "https",
86 OriginProtocol::Custom(ref protocol) => protocol,
87 },
88 &**host,
89 )
90 }
91}
92
93impl Pattern for Origin {
94 fn matches<T: AsRef<str>>(&self, other: T) -> bool {
95 self.matcher.matches(other)
96 }
97}
98
99impl ops::Deref for Origin {
100 type Target = str;
101 fn deref(&self) -> &Self::Target {
102 &self.as_string
103 }
104}
105
106#[derive(Debug, Clone, PartialEq, Eq)]
108pub enum AccessControlAllowOrigin {
109 Value(Origin),
111 Null,
113 Any,
115}
116
117impl fmt::Display for AccessControlAllowOrigin {
118 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
119 write!(f, "{}", match *self {
120 AccessControlAllowOrigin::Any => "*",
121 AccessControlAllowOrigin::Null => "null",
122 AccessControlAllowOrigin::Value(ref val) => val,
123 })
124 }
125}
126
127impl<T: Into<String>> From<T> for AccessControlAllowOrigin {
128 fn from(s: T) -> AccessControlAllowOrigin {
129 match s.into().as_str() {
130 "all" | "*" | "any" => AccessControlAllowOrigin::Any,
131 "null" => AccessControlAllowOrigin::Null,
132 origin => AccessControlAllowOrigin::Value(origin.into()),
133 }
134 }
135}
136
/// Configuration of the request headers accepted by the CORS check.
///
/// `Eq` is derived alongside `PartialEq`, in line with the other CORS types
/// in this module, so that wrappers such as `AllowCors<_>` stay `Eq` too.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AccessControlAllowHeaders {
    /// Only the listed header names are accepted.
    Only(Vec<String>),
    /// Every header is accepted.
    Any,
}
145
/// Outcome of a CORS check.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AllowCors<T> {
    /// No CORS header needs to be produced.
    NotRequired,
    /// The request failed the check.
    Invalid,
    /// The check passed; carries the value to use.
    Ok(T),
}

impl<T> AllowCors<T> {
    /// Transforms the value carried by `Ok` with `f`; `NotRequired` and
    /// `Invalid` pass through untouched.
    pub fn map<F, O>(self, f: F) -> AllowCors<O>
    where
        F: FnOnce(T) -> O,
    {
        match self {
            AllowCors::Ok(inner) => AllowCors::Ok(f(inner)),
            AllowCors::Invalid => AllowCors::Invalid,
            AllowCors::NotRequired => AllowCors::NotRequired,
        }
    }
}

/// Collapses the outcome into an `Option`: only `Ok` yields `Some`.
impl<T> Into<Option<T>> for AllowCors<T> {
    fn into(self) -> Option<T> {
        if let AllowCors::Ok(header) = self {
            Some(header)
        } else {
            None
        }
    }
}
182
183pub fn get_cors_allow_origin(
185 origin: Option<&str>,
186 host: Option<&str>,
187 allowed: &Option<Vec<AccessControlAllowOrigin>>
188) -> AllowCors<AccessControlAllowOrigin> {
189 match origin {
190 None => AllowCors::NotRequired,
191 Some(ref origin) => {
192 if let Some(host) = host {
193 if origin.ends_with(host) {
195 let origin = Origin::parse(origin);
197 if &*origin.host == host {
198 return AllowCors::NotRequired;
199 }
200 }
201 }
202
203 match allowed.as_ref() {
204 None if *origin == "null" => AllowCors::Ok(AccessControlAllowOrigin::Null),
205 None => AllowCors::Ok(AccessControlAllowOrigin::Value(Origin::parse(origin))),
206 Some(ref allowed) if *origin == "null" => {
207 allowed.iter().find(|cors| **cors == AccessControlAllowOrigin::Null).cloned()
208 .map(AllowCors::Ok)
209 .unwrap_or(AllowCors::Invalid)
210 },
211 Some(ref allowed) => {
212 allowed.iter().find(|cors| {
213 match **cors {
214 AccessControlAllowOrigin::Any => true,
215 AccessControlAllowOrigin::Value(ref val) if val.matches(origin) =>
216 {
217 true
218 },
219 _ => false
220 }
221 })
222 .map(|_| AccessControlAllowOrigin::Value(Origin::parse(origin)))
223 .map(AllowCors::Ok).unwrap_or(AllowCors::Invalid)
224 },
225 }
226 },
227 }
228}
229
230pub fn get_cors_allow_headers<T: AsRef<str>, O, F: Fn(T) -> O>(
232 mut headers: impl Iterator<Item=T>,
233 requested_headers: impl Iterator<Item=T>,
234 cors_allow_headers: &AccessControlAllowHeaders,
235 to_result: F
236) -> AllowCors<Vec<O>> {
237 if let AccessControlAllowHeaders::Only(only) = cors_allow_headers {
239 let are_all_allowed = headers
240 .all(|header| {
241 let name = &Ascii::new(header.as_ref());
242 only.iter().any(|h| &Ascii::new(&*h) == name) || ALWAYS_ALLOWED_HEADERS.contains(name)
243 });
244
245 if !are_all_allowed {
246 return AllowCors::Invalid;
247 }
248 }
249
250 let (filtered, headers) = match cors_allow_headers {
252 AccessControlAllowHeaders::Any => {
253 let headers = requested_headers.map(to_result).collect();
254 (false, headers)
255 },
256 AccessControlAllowHeaders::Only(only) => {
257 let mut filtered = false;
258 let headers: Vec<_> = requested_headers
259 .filter(|header| {
260 let name = &Ascii::new(header.as_ref());
261 filtered = true;
262 only.iter().any(|h| &Ascii::new(&*h) == name) || ALWAYS_ALLOWED_HEADERS.contains(name)
263 })
264 .map(to_result)
265 .collect();
266
267 (filtered, headers)
268 },
269 };
270
271 if headers.is_empty() {
272 if filtered {
273 AllowCors::Invalid
274 } else {
275 AllowCors::NotRequired
276 }
277 } else {
278 AllowCors::Ok(headers)
279 }
280}
281
282lazy_static! {
284 static ref ALWAYS_ALLOWED_HEADERS: HashSet<Ascii<&'static str>> = {
285 let mut hs = HashSet::new();
286 hs.insert(Ascii::new("Accept"));
287 hs.insert(Ascii::new("Accept-Language"));
288 hs.insert(Ascii::new("Access-Control-Allow-Origin"));
289 hs.insert(Ascii::new("Access-Control-Request-Headers"));
290 hs.insert(Ascii::new("Content-Language"));
291 hs.insert(Ascii::new("Content-Type"));
292 hs.insert(Ascii::new("Host"));
293 hs.insert(Ascii::new("Origin"));
294 hs.insert(Ascii::new("Content-Length"));
295 hs.insert(Ascii::new("Connection"));
296 hs.insert(Ascii::new("User-Agent"));
297 hs
298 };
299}
300
#[cfg(test)]
mod tests {
    use std::iter;

    use super::*;
    use hosts::Host;

    // Shorthand for an allow-list entry holding a parsed origin.
    fn value(origin: &str) -> AccessControlAllowOrigin {
        AccessControlAllowOrigin::Value(Origin::parse(origin))
    }

    #[test]
    fn should_parse_origin() {
        use self::OriginProtocol::*;

        let cases = vec![
            ("http://parity.io", Origin::new(Http, "parity.io", None)),
            ("https://parity.io:8443", Origin::new(Https, "parity.io", Some(8443))),
            (
                "chrome-extension://124.0.0.1",
                Origin::new(Custom("chrome-extension".into()), "124.0.0.1", None),
            ),
            ("parity.io/somepath", Origin::new(Http, "parity.io", None)),
            ("127.0.0.1:8545/somepath", Origin::new(Http, "127.0.0.1", Some(8545))),
        ];

        for (input, expected) in cases {
            assert_eq!(Origin::parse(input), expected);
        }
    }

    #[test]
    fn should_not_allow_partially_matching_origin() {
        // given
        let subdomain_origin = Origin::parse("http://subdomain.somedomain.io");
        let other_port_origin = Origin::parse("http://somedomain.io:8080");
        let parsed_host = Host::parse("http://somedomain.io");
        let host = Some(&*parsed_host);
        let nothing_allowed = Some(Vec::new());

        // when
        let first = get_cors_allow_origin(Some(&*subdomain_origin), host, &nothing_allowed);
        let second = get_cors_allow_origin(Some(&*other_port_origin), host, &nothing_allowed);

        // then
        assert_eq!(first, AllowCors::Invalid);
        assert_eq!(second, AllowCors::Invalid);
    }

    #[test]
    fn should_allow_origins_that_matches_hosts() {
        // given
        let parsed_origin = Origin::parse("http://127.0.0.1:8080");
        let parsed_host = Host::parse("http://127.0.0.1:8080");

        // when
        let outcome = get_cors_allow_origin(Some(&*parsed_origin), Some(&*parsed_host), &None);

        // then
        assert_eq!(outcome, AllowCors::NotRequired);
    }

    #[test]
    fn should_return_none_when_there_are_no_cors_domains_and_no_origin() {
        let outcome = get_cors_allow_origin(None, None, &None);

        assert_eq!(outcome, AllowCors::NotRequired);
    }

    #[test]
    fn should_return_domain_when_all_are_allowed() {
        let outcome = get_cors_allow_origin(Some("parity.io"), None, &None);

        assert_eq!(outcome, AllowCors::Ok(AccessControlAllowOrigin::from("parity.io")));
    }

    #[test]
    fn should_return_none_for_empty_origin() {
        let allowed = Some(vec![value("http://ethereum.org")]);

        let outcome = get_cors_allow_origin(None, None, &allowed);

        assert_eq!(outcome, AllowCors::NotRequired);
    }

    #[test]
    fn should_return_none_for_empty_list() {
        let outcome = get_cors_allow_origin(None, None, &Some(Vec::new()));

        assert_eq!(outcome, AllowCors::NotRequired);
    }

    #[test]
    fn should_return_none_for_not_matching_origin() {
        let allowed = Some(vec![value("http://ethereum.org")]);

        let outcome = get_cors_allow_origin(Some("http://parity.io"), None, &allowed);

        assert_eq!(outcome, AllowCors::Invalid);
    }

    #[test]
    fn should_return_specific_origin_if_we_allow_any() {
        let allowed = Some(vec![AccessControlAllowOrigin::Any]);

        let outcome = get_cors_allow_origin(Some("http://parity.io"), None, &allowed);

        assert_eq!(outcome, AllowCors::Ok(value("http://parity.io")));
    }

    #[test]
    fn should_return_none_if_origin_is_not_defined() {
        let allowed = Some(vec![AccessControlAllowOrigin::Null]);

        let outcome = get_cors_allow_origin(None, None, &allowed);

        assert_eq!(outcome, AllowCors::NotRequired);
    }

    #[test]
    fn should_return_null_if_origin_is_null() {
        let allowed = Some(vec![AccessControlAllowOrigin::Null]);

        let outcome = get_cors_allow_origin(Some("null"), None, &allowed);

        assert_eq!(outcome, AllowCors::Ok(AccessControlAllowOrigin::Null));
    }

    #[test]
    fn should_return_specific_origin_if_there_is_a_match() {
        let allowed = Some(vec![value("http://ethereum.org"), value("http://parity.io")]);

        let outcome = get_cors_allow_origin(Some("http://parity.io"), None, &allowed);

        assert_eq!(outcome, AllowCors::Ok(value("http://parity.io")));
    }

    #[test]
    fn should_support_wildcards() {
        // given
        let allowed = Some(vec![value("http://*.io"), value("chrome-extension://*")]);

        // when
        let matching = get_cors_allow_origin(Some("http://parity.io"), None, &allowed);
        let near_miss = get_cors_allow_origin(Some("http://parity.iot"), None, &allowed);
        let extension = get_cors_allow_origin(Some("chrome-extension://test"), None, &allowed);

        // then
        assert_eq!(matching, AllowCors::Ok(value("http://parity.io")));
        assert_eq!(near_miss, AllowCors::Invalid);
        assert_eq!(extension, AllowCors::Ok(value("chrome-extension://test")));
    }

    #[test]
    fn should_return_invalid_if_header_not_allowed() {
        // given
        let config = AccessControlAllowHeaders::Only(vec!["x-allowed".to_owned()]);
        let present = vec!["Access-Control-Request-Headers"];
        let requested = vec!["x-not-allowed"];

        // when
        let outcome = get_cors_allow_headers(present.iter(), requested.iter(), &config, |h| h);

        // then
        assert_eq!(outcome, AllowCors::Invalid);
    }

    #[test]
    fn should_return_valid_if_header_allowed() {
        // given
        let allowed = vec!["x-allowed".to_owned()];
        let config = AccessControlAllowHeaders::Only(allowed.clone());
        let present = vec!["Access-Control-Request-Headers"];
        let requested = vec!["x-allowed"];

        // when
        let outcome = get_cors_allow_headers(
            present.iter(),
            requested.iter(),
            &config,
            |h| (*h).to_owned(),
        );

        // then
        assert_eq!(outcome, AllowCors::Ok(allowed));
    }

    #[test]
    fn should_return_no_allowed_headers_if_none_in_request() {
        // given
        let config = AccessControlAllowHeaders::Only(vec!["x-allowed".to_owned()]);
        let present: Vec<String> = Vec::new();

        // when
        let outcome = get_cors_allow_headers(present.iter(), iter::empty(), &config, |h| h);

        // then
        assert_eq!(outcome, AllowCors::NotRequired);
    }

    #[test]
    fn should_return_not_required_if_any_header_allowed() {
        // given
        let config = AccessControlAllowHeaders::Any;
        let present: Vec<String> = Vec::new();

        // when
        let outcome = get_cors_allow_headers(present.iter(), iter::empty(), &config, |h| h);

        // then
        assert_eq!(outcome, AllowCors::NotRequired);
    }
}