bitconch_jsonrpc_server_utils/
cors.rs

1//! CORS handling utility functions
2extern crate unicase;
3
4use std::{fmt, ops};
5use hosts::{Host, Port};
6use matcher::{Matcher, Pattern};
7use std::collections::HashSet;
8pub use self::unicase::Ascii;
9
/// Scheme part of a request origin.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OriginProtocol {
	/// Plain `http` scheme.
	Http,
	/// Secure `https` scheme.
	Https,
	/// Any other scheme, kept by name (for example `chrome-extension`).
	Custom(String),
}
20
21/// Request Origin
22#[derive(Clone, PartialEq, Eq, Debug, Hash)]
23pub struct Origin {
24	protocol: OriginProtocol,
25	host: Host,
26	as_string: String,
27	matcher: Matcher,
28}
29
30impl<T: AsRef<str>> From<T> for Origin {
31	fn from(string: T) -> Self {
32		Origin::parse(string.as_ref())
33	}
34}
35
36impl Origin {
37	fn with_host(protocol: OriginProtocol, host: Host) -> Self {
38		let string = Self::to_string(&protocol, &host);
39		let matcher = Matcher::new(&string);
40
41		Origin {
42			protocol: protocol,
43			host: host,
44			as_string: string,
45			matcher: matcher,
46		}
47	}
48
49	/// Creates new origin given protocol, hostname and port parts.
50	/// Pre-processes input data if necessary.
51	pub fn new<T: Into<Port>>(protocol: OriginProtocol, host: &str, port: T) -> Self {
52		Self::with_host(protocol, Host::new(host, port))
53	}
54
55	/// Attempts to parse given string as a `Origin`.
56	/// NOTE: This method always succeeds and falls back to sensible defaults.
57	pub fn parse(data: &str) -> Self {
58		let mut it = data.split("://");
59		let proto = it.next().expect("split always returns non-empty iterator.");
60		let hostname = it.next();
61
62		let (proto, hostname) = match hostname {
63			None => (None, proto),
64			Some(hostname) => (Some(proto), hostname),
65		};
66
67		let proto = proto.map(str::to_lowercase);
68		let hostname = Host::parse(hostname);
69
70		let protocol = match proto {
71			None => OriginProtocol::Http,
72			Some(ref p) if p == "http" => OriginProtocol::Http,
73			Some(ref p) if p == "https" => OriginProtocol::Https,
74			Some(other) => OriginProtocol::Custom(other),
75		};
76
77		Origin::with_host(protocol, hostname)
78	}
79
80	fn to_string(protocol: &OriginProtocol, host: &Host) -> String {
81		format!(
82			"{}://{}",
83			match *protocol {
84				OriginProtocol::Http => "http",
85				OriginProtocol::Https => "https",
86				OriginProtocol::Custom(ref protocol) => protocol,
87			},
88			&**host,
89		)
90	}
91}
92
93impl Pattern for Origin {
94	fn matches<T: AsRef<str>>(&self, other: T) -> bool {
95		self.matcher.matches(other)
96	}
97}
98
99impl ops::Deref for Origin {
100	type Target = str;
101	fn deref(&self) -> &Self::Target {
102		&self.as_string
103	}
104}
105
106/// Origins allowed to access
107#[derive(Debug, Clone, PartialEq, Eq)]
108pub enum AccessControlAllowOrigin {
109	/// Specific hostname
110	Value(Origin),
111	/// null-origin (file:///, sandboxed iframe)
112	Null,
113	/// Any non-null origin
114	Any,
115}
116
117impl fmt::Display for AccessControlAllowOrigin {
118	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
119		write!(f, "{}", match *self {
120			AccessControlAllowOrigin::Any => "*",
121			AccessControlAllowOrigin::Null => "null",
122			AccessControlAllowOrigin::Value(ref val) => val,
123		})
124	}
125}
126
127impl<T: Into<String>> From<T> for AccessControlAllowOrigin {
128	fn from(s: T) -> AccessControlAllowOrigin {
129		match s.into().as_str() {
130			"all" | "*" | "any" => AccessControlAllowOrigin::Any,
131			"null" => AccessControlAllowOrigin::Null,
132			origin => AccessControlAllowOrigin::Value(origin.into()),
133		}
134	}
135}
136
/// Headers allowed to access
///
/// Derives `Eq` in addition to `PartialEq`, matching the other CORS types in this
/// module; equality on `Vec<String>` is total.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AccessControlAllowHeaders {
	/// Only the listed header names are allowed.
	Only(Vec<String>),
	/// Any header is allowed.
	Any,
}
145
/// Outcome of a CORS check, optionally carrying the value to put in the response header.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AllowCors<T> {
	/// No CORS header needs to be sent (for example the request carried no Origin).
	NotRequired,
	/// The request is rejected; no CORS header is returned.
	Invalid,
	/// The request is accepted; the payload is the header to include in the response.
	Ok(T),
}

impl<T> AllowCors<T> {
	/// Transforms the payload of `Ok` with `f`; `NotRequired` and `Invalid` pass through unchanged.
	pub fn map<F, O>(self, f: F) -> AllowCors<O> where
		F: FnOnce(T) -> O,
	{
		match self {
			AllowCors::Ok(inner) => AllowCors::Ok(f(inner)),
			AllowCors::Invalid => AllowCors::Invalid,
			AllowCors::NotRequired => AllowCors::NotRequired,
		}
	}
}

impl<T> Into<Option<T>> for AllowCors<T> {
	/// Keeps the payload of `Ok` as `Some`; both other variants collapse to `None`.
	fn into(self) -> Option<T> {
		if let AllowCors::Ok(header) = self {
			Some(header)
		} else {
			None
		}
	}
}
182
183/// Returns correct CORS header (if any) given list of allowed origins and current origin.
184pub fn get_cors_allow_origin(
185	origin: Option<&str>,
186	host: Option<&str>,
187	allowed: &Option<Vec<AccessControlAllowOrigin>>
188) -> AllowCors<AccessControlAllowOrigin> {
189	match origin {
190		None => AllowCors::NotRequired,
191		Some(ref origin) => {
192			if let Some(host) = host {
193				// Request initiated from the same server.
194				if origin.ends_with(host) {
195					// Additional check
196					let origin = Origin::parse(origin);
197					if &*origin.host == host {
198						return AllowCors::NotRequired;
199					}
200				}
201			}
202
203			match allowed.as_ref() {
204				None if *origin == "null" => AllowCors::Ok(AccessControlAllowOrigin::Null),
205				None => AllowCors::Ok(AccessControlAllowOrigin::Value(Origin::parse(origin))),
206				Some(ref allowed) if *origin == "null" => {
207					allowed.iter().find(|cors| **cors == AccessControlAllowOrigin::Null).cloned()
208						.map(AllowCors::Ok)
209						.unwrap_or(AllowCors::Invalid)
210				},
211				Some(ref allowed) => {
212					allowed.iter().find(|cors| {
213						match **cors {
214							AccessControlAllowOrigin::Any => true,
215							AccessControlAllowOrigin::Value(ref val) if val.matches(origin) =>
216							{
217								true
218							},
219							_ => false
220						}
221					})
222					.map(|_| AccessControlAllowOrigin::Value(Origin::parse(origin)))
223					.map(AllowCors::Ok).unwrap_or(AllowCors::Invalid)
224				},
225			}
226		},
227	}
228}
229
230/// Validates if the `AccessControlAllowedHeaders` in the request are allowed.
231pub fn get_cors_allow_headers<T: AsRef<str>, O, F: Fn(T) -> O>(
232	mut headers: impl Iterator<Item=T>,
233	requested_headers: impl Iterator<Item=T>,
234	cors_allow_headers: &AccessControlAllowHeaders,
235	to_result: F
236) -> AllowCors<Vec<O>> {
237	// Check if the header fields which were sent in the request are allowed
238	if let AccessControlAllowHeaders::Only(only) = cors_allow_headers {
239		let are_all_allowed = headers
240			.all(|header| {
241				let name = &Ascii::new(header.as_ref());
242				only.iter().any(|h| &Ascii::new(&*h) == name) || ALWAYS_ALLOWED_HEADERS.contains(name)
243			});
244
245		if !are_all_allowed {
246			return AllowCors::Invalid;
247		}
248	}
249
250	// Check if `AccessControlRequestHeaders` contains fields which were allowed
251	let (filtered, headers) = match cors_allow_headers {
252		AccessControlAllowHeaders::Any => {
253			let headers = requested_headers.map(to_result).collect();
254			(false, headers)
255		},
256		AccessControlAllowHeaders::Only(only) => {
257			let mut filtered = false;
258			let headers: Vec<_> = requested_headers
259				.filter(|header| {
260					let name = &Ascii::new(header.as_ref());
261					filtered = true;
262					only.iter().any(|h| &Ascii::new(&*h) == name) || ALWAYS_ALLOWED_HEADERS.contains(name)
263				})
264				.map(to_result)
265				.collect();
266
267			(filtered, headers)
268		},
269	};
270
271	if headers.is_empty() {
272		if filtered {
273			AllowCors::Invalid
274		} else {
275			AllowCors::NotRequired
276		}
277	} else {
278		AllowCors::Ok(headers)
279	}
280}
281
282/// Returns headers which are always allowed.
283lazy_static! {
284	static ref ALWAYS_ALLOWED_HEADERS: HashSet<Ascii<&'static str>> = {
285		let mut hs = HashSet::new();
286		hs.insert(Ascii::new("Accept"));
287		hs.insert(Ascii::new("Accept-Language"));
288		hs.insert(Ascii::new("Access-Control-Allow-Origin"));
289		hs.insert(Ascii::new("Access-Control-Request-Headers"));
290		hs.insert(Ascii::new("Content-Language"));
291		hs.insert(Ascii::new("Content-Type"));
292		hs.insert(Ascii::new("Host"));
293		hs.insert(Ascii::new("Origin"));
294		hs.insert(Ascii::new("Content-Length"));
295		hs.insert(Ascii::new("Connection"));
296		hs.insert(Ascii::new("User-Agent"));
297		hs
298	};
299}
300
#[cfg(test)]
mod tests {
	use std::iter;

	use super::*;
	use hosts::Host;

	/// Shorthand for an allow-list entry (or expected header) holding the given origin.
	fn value(origin: &str) -> AccessControlAllowOrigin {
		AccessControlAllowOrigin::Value(origin.into())
	}

	#[test]
	fn should_parse_origin() {
		use self::OriginProtocol::*;

		let cases = vec![
			("http://parity.io", Origin::new(Http, "parity.io", None)),
			("https://parity.io:8443", Origin::new(Https, "parity.io", Some(8443))),
			("chrome-extension://124.0.0.1", Origin::new(Custom("chrome-extension".into()), "124.0.0.1", None)),
			("parity.io/somepath", Origin::new(Http, "parity.io", None)),
			("127.0.0.1:8545/somepath", Origin::new(Http, "127.0.0.1", Some(8545))),
		];

		for (input, expected) in cases {
			assert_eq!(Origin::parse(input), expected, "parsing {:?}", input);
		}
	}

	#[test]
	fn should_not_allow_partially_matching_origin() {
		// A subdomain and a different port both merely overlap with the host.
		let host = Host::parse("http://somedomain.io");
		let nothing_allowed = Some(Vec::new());

		for raw in &["http://subdomain.somedomain.io", "http://somedomain.io:8080"] {
			let origin = Origin::parse(raw);
			let res = get_cors_allow_origin(Some(&*origin), Some(&*host), &nothing_allowed);
			assert_eq!(res, AllowCors::Invalid);
		}
	}

	#[test]
	fn should_allow_origins_that_matches_hosts() {
		let origin = Origin::parse("http://127.0.0.1:8080");
		let host = Host::parse("http://127.0.0.1:8080");

		let res = get_cors_allow_origin(Some(&*origin), Some(&*host), &None);

		assert_eq!(res, AllowCors::NotRequired);
	}

	#[test]
	fn should_return_none_when_there_are_no_cors_domains_and_no_origin() {
		let res = get_cors_allow_origin(None, None, &None);

		assert_eq!(res, AllowCors::NotRequired);
	}

	#[test]
	fn should_return_domain_when_all_are_allowed() {
		let res = get_cors_allow_origin(Some("parity.io"), None, &None);

		assert_eq!(res, AllowCors::Ok("parity.io".into()));
	}

	#[test]
	fn should_return_none_for_empty_origin() {
		let allowed = Some(vec![value("http://ethereum.org")]);

		let res = get_cors_allow_origin(None, None, &allowed);

		assert_eq!(res, AllowCors::NotRequired);
	}

	#[test]
	fn should_return_none_for_empty_list() {
		let res = get_cors_allow_origin(None, None, &Some(Vec::new()));

		assert_eq!(res, AllowCors::NotRequired);
	}

	#[test]
	fn should_return_none_for_not_matching_origin() {
		let allowed = Some(vec![value("http://ethereum.org")]);

		let res = get_cors_allow_origin(Some("http://parity.io"), None, &allowed);

		assert_eq!(res, AllowCors::Invalid);
	}

	#[test]
	fn should_return_specific_origin_if_we_allow_any() {
		let allowed = Some(vec![AccessControlAllowOrigin::Any]);

		let res = get_cors_allow_origin(Some("http://parity.io"), None, &allowed);

		assert_eq!(res, AllowCors::Ok(value("http://parity.io")));
	}

	#[test]
	fn should_return_none_if_origin_is_not_defined() {
		let allowed = Some(vec![AccessControlAllowOrigin::Null]);

		let res = get_cors_allow_origin(None, None, &allowed);

		assert_eq!(res, AllowCors::NotRequired);
	}

	#[test]
	fn should_return_null_if_origin_is_null() {
		let allowed = Some(vec![AccessControlAllowOrigin::Null]);

		let res = get_cors_allow_origin(Some("null"), None, &allowed);

		assert_eq!(res, AllowCors::Ok(AccessControlAllowOrigin::Null));
	}

	#[test]
	fn should_return_specific_origin_if_there_is_a_match() {
		let allowed = Some(vec![value("http://ethereum.org"), value("http://parity.io")]);

		let res = get_cors_allow_origin(Some("http://parity.io"), None, &allowed);

		assert_eq!(res, AllowCors::Ok(value("http://parity.io")));
	}

	#[test]
	fn should_support_wildcards() {
		let allowed = Some(vec![value("http://*.io"), value("chrome-extension://*")]);
		let check = |origin: &str| get_cors_allow_origin(Some(origin), None, &allowed);

		assert_eq!(check("http://parity.io"), AllowCors::Ok(value("http://parity.io")));
		assert_eq!(check("http://parity.iot"), AllowCors::Invalid);
		assert_eq!(check("chrome-extension://test"), AllowCors::Ok(value("chrome-extension://test")));
	}

	#[test]
	fn should_return_invalid_if_header_not_allowed() {
		let policy = AccessControlAllowHeaders::Only(vec!["x-allowed".to_owned()]);
		let sent = vec!["Access-Control-Request-Headers"];
		let requested = vec!["x-not-allowed"];

		let res = get_cors_allow_headers(sent.iter(), requested.iter(), &policy, |x| x);

		assert_eq!(res, AllowCors::Invalid);
	}

	#[test]
	fn should_return_valid_if_header_allowed() {
		let policy = AccessControlAllowHeaders::Only(vec!["x-allowed".to_owned()]);
		let sent = vec!["Access-Control-Request-Headers"];
		let requested = vec!["x-allowed"];

		let res = get_cors_allow_headers(sent.iter(), requested.iter(), &policy, |x| (*x).to_owned());

		assert_eq!(res, AllowCors::Ok(vec!["x-allowed".to_owned()]));
	}

	#[test]
	fn should_return_no_allowed_headers_if_none_in_request() {
		let policy = AccessControlAllowHeaders::Only(vec!["x-allowed".to_owned()]);
		let sent: Vec<String> = Vec::new();

		let res = get_cors_allow_headers(sent.iter(), iter::empty(), &policy, |x| x);

		assert_eq!(res, AllowCors::NotRequired);
	}

	#[test]
	fn should_return_not_required_if_any_header_allowed() {
		let policy = AccessControlAllowHeaders::Any;
		let sent: Vec<String> = Vec::new();

		let res = get_cors_allow_headers(sent.iter(), iter::empty(), &policy, |x| x);

		assert_eq!(res, AllowCors::NotRequired);
	}

}