rs_jsonrpc_server_utils/
cors.rs

1//! CORS handling utility functions
2
3use std::{fmt, ops};
4use hosts::{Host, Port};
5use matcher::{Matcher, Pattern};
6
/// Scheme part of a request origin (the text in front of `://`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OriginProtocol {
	/// Plain `http` scheme.
	Http,
	/// Secure `https` scheme.
	Https,
	/// Any other scheme, e.g. `chrome-extension`.
	/// When produced by `Origin::parse` the name is lowercased.
	Custom(String),
}
17
18/// Request Origin
19#[derive(Clone, PartialEq, Eq, Debug, Hash)]
20pub struct Origin {
21	protocol: OriginProtocol,
22	host: Host,
23	as_string: String,
24	matcher: Matcher,
25}
26
27impl<T: AsRef<str>> From<T> for Origin {
28	fn from(string: T) -> Self {
29		Origin::parse(string.as_ref())
30	}
31}
32
33impl Origin {
34	fn with_host(protocol: OriginProtocol, host: Host) -> Self {
35		let string = Self::to_string(&protocol, &host);
36		let matcher = Matcher::new(&string);
37
38		Origin {
39			protocol: protocol,
40			host: host,
41			as_string: string,
42			matcher: matcher,
43		}
44	}
45
46	/// Creates new origin given protocol, hostname and port parts.
47	/// Pre-processes input data if necessary.
48	pub fn new<T: Into<Port>>(protocol: OriginProtocol, host: &str, port: T) -> Self {
49		Self::with_host(protocol, Host::new(host, port))
50	}
51
52	/// Attempts to parse given string as a `Origin`.
53	/// NOTE: This method always succeeds and falls back to sensible defaults.
54	pub fn parse(data: &str) -> Self {
55		let mut it = data.split("://");
56		let proto = it.next().expect("split always returns non-empty iterator.");
57		let hostname = it.next();
58
59		let (proto, hostname) = match hostname {
60			None => (None, proto),
61			Some(hostname) => (Some(proto), hostname),
62		};
63
64		let proto = proto.map(str::to_lowercase);
65		let hostname = Host::parse(hostname);
66
67		let protocol = match proto {
68			None => OriginProtocol::Http,
69			Some(ref p) if p == "http" => OriginProtocol::Http,
70			Some(ref p) if p == "https" => OriginProtocol::Https,
71			Some(other) => OriginProtocol::Custom(other),
72		};
73
74		Origin::with_host(protocol, hostname)
75	}
76
77	fn to_string(protocol: &OriginProtocol, host: &Host) -> String {
78		format!(
79			"{}://{}",
80			match *protocol {
81				OriginProtocol::Http => "http",
82				OriginProtocol::Https => "https",
83				OriginProtocol::Custom(ref protocol) => protocol,
84			},
85			&**host,
86		)
87	}
88}
89
90impl Pattern for Origin {
91	fn matches<T: AsRef<str>>(&self, other: T) -> bool {
92		self.matcher.matches(other)
93	}
94}
95
96impl ops::Deref for Origin {
97	type Target = str;
98	fn deref(&self) -> &Self::Target {
99		&self.as_string
100	}
101}
102
103/// Origins allowed to access
104#[derive(Debug, Clone, PartialEq, Eq)]
105pub enum AccessControlAllowOrigin {
106	/// Specific hostname
107	Value(Origin),
108	/// null-origin (file:///, sandboxed iframe)
109	Null,
110	/// Any non-null origin
111	Any,
112}
113
114impl fmt::Display for AccessControlAllowOrigin {
115	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
116		write!(f, "{}", match *self {
117			AccessControlAllowOrigin::Any => "*",
118			AccessControlAllowOrigin::Null => "null",
119			AccessControlAllowOrigin::Value(ref val) => val,
120		})
121	}
122}
123
124impl<T: Into<String>> From<T> for AccessControlAllowOrigin {
125	fn from(s: T) -> AccessControlAllowOrigin {
126		match s.into().as_str() {
127			"all" | "*" | "any" => AccessControlAllowOrigin::Any,
128			"null" => AccessControlAllowOrigin::Null,
129			origin => AccessControlAllowOrigin::Value(origin.into()),
130		}
131	}
132}
133
134/// CORS Header Result.
135#[derive(Debug, Clone, PartialEq, Eq)]
136pub enum CorsHeader<T = AccessControlAllowOrigin> {
137	/// CORS header was not required. Origin is not present in the request.
138	NotRequired,
139	/// CORS header is not returned, Origin is not allowed to access the resource.
140	Invalid,
141	/// CORS header to include in the response. Origin is allowed to access the resource.
142	Ok(T),
143}
144
145impl<T> CorsHeader<T> {
146	/// Maps `Ok` variant of `CorsHeader`.
147	pub fn map<F, O>(self, f: F) -> CorsHeader<O> where
148		F: FnOnce(T) -> O,
149	{
150		use self::CorsHeader::*;
151
152		match self {
153			NotRequired => NotRequired,
154			Invalid => Invalid,
155			Ok(val) => Ok(f(val)),
156		}
157	}
158}
159
160impl<T> Into<Option<T>> for CorsHeader<T> {
161	fn into(self) -> Option<T> {
162		use self::CorsHeader::*;
163
164		match self {
165			NotRequired | Invalid => None,
166			Ok(header) => Some(header),
167		}
168	}
169}
170
171/// Returns correct CORS header (if any) given list of allowed origins and current origin.
172pub fn get_cors_header(origin: Option<&str>, host: Option<&str>, allowed: &Option<Vec<AccessControlAllowOrigin>>) -> CorsHeader {
173	match origin {
174		None => CorsHeader::NotRequired,
175		Some(ref origin) => {
176			if let Some(host) = host {
177				// Request initiated from the same server.
178				if origin.ends_with(host) {
179					// Additional check
180					let origin = Origin::parse(origin);
181					if &*origin.host == host {
182						return CorsHeader::NotRequired;
183					}
184				}
185			}
186
187			match allowed.as_ref() {
188				None => CorsHeader::Ok(AccessControlAllowOrigin::Value(Origin::parse(origin))),
189				Some(ref allowed) if *origin == "null" => {
190					allowed.iter().find(|cors| **cors == AccessControlAllowOrigin::Null).cloned()
191						.map(CorsHeader::Ok)
192						.unwrap_or(CorsHeader::Invalid)
193				},
194				Some(ref allowed) => {
195					allowed.iter().find(|cors| {
196						match **cors {
197							AccessControlAllowOrigin::Any => true,
198							AccessControlAllowOrigin::Value(ref val) if val.matches(origin) => true,
199							_ => false
200						}
201					})
202					.map(|_| AccessControlAllowOrigin::Value(Origin::parse(origin)))
203					.map(CorsHeader::Ok).unwrap_or(CorsHeader::Invalid)
204				},
205			}
206		},
207	}
208}
209
210
#[cfg(test)]
mod tests {
	use hosts::Host;
	use super::{get_cors_header, CorsHeader, AccessControlAllowOrigin, Origin, OriginProtocol};

	#[test]
	fn should_parse_origin() {
		use self::OriginProtocol::*;

		assert_eq!(Origin::parse("http://superstring.ch"), Origin::new(Http, "superstring.ch", None));
		// The scheme has to be `https` for the parsed protocol to be `Https`.
		assert_eq!(Origin::parse("https://superstring.ch:8443"), Origin::new(Https, "superstring.ch", Some(8443)));
		assert_eq!(Origin::parse("chrome-extension://124.0.0.1"), Origin::new(Custom("chrome-extension".into()), "124.0.0.1", None));
		assert_eq!(Origin::parse("superstring.ch/somepath"), Origin::new(Http, "superstring.ch", None));
		assert_eq!(Origin::parse("127.0.0.1:8545/somepath"), Origin::new(Http, "127.0.0.1", Some(8545)));
	}

	#[test]
	fn should_not_allow_partially_matching_origin() {
		// given
		let origin1 = Origin::parse("http://subdomain.somedomain.io");
		let origin2 = Origin::parse("http://somedomain.io:8080");
		let host = Host::parse("http://somedomain.io");

		let origin1 = Some(&*origin1);
		let origin2 = Some(&*origin2);
		let host = Some(&*host);

		// when
		let res1 = get_cors_header(origin1, host, &Some(vec![]));
		let res2 = get_cors_header(origin2, host, &Some(vec![]));

		// then
		assert_eq!(res1, CorsHeader::Invalid);
		assert_eq!(res2, CorsHeader::Invalid);
	}

	#[test]
	fn should_allow_origins_that_matches_hosts() {
		// given
		let origin = Origin::parse("http://127.0.0.1:8080");
		let host = Host::parse("http://127.0.0.1:8080");

		let origin = Some(&*origin);
		let host = Some(&*host);

		// when
		let res = get_cors_header(origin, host, &None);

		// then
		assert_eq!(res, CorsHeader::NotRequired);
	}

	#[test]
	fn should_return_none_when_there_are_no_cors_domains_and_no_origin() {
		// given
		let origin = None;
		let host = None;

		// when
		let res = get_cors_header(origin, host, &None);

		// then
		assert_eq!(res, CorsHeader::NotRequired);
	}

	#[test]
	fn should_return_domain_when_all_are_allowed() {
		// given
		let origin = Some("superstring.ch");
		let host = None;

		// when
		let res = get_cors_header(origin, host, &None);

		// then
		assert_eq!(res, CorsHeader::Ok("superstring.ch".into()));
	}

	#[test]
	fn should_return_none_for_empty_origin() {
		// given
		let origin = None;
		let host = None;

		// when
		let res = get_cors_header(
			origin,
			host,
			&Some(vec![AccessControlAllowOrigin::Value("http://sophon.org".into())]),
		);

		// then
		assert_eq!(res, CorsHeader::NotRequired);
	}

	#[test]
	fn should_return_none_for_empty_list() {
		// given
		let origin = None;
		let host = None;

		// when
		let res = get_cors_header(origin, host, &Some(Vec::new()));

		// then
		assert_eq!(res, CorsHeader::NotRequired);
	}

	#[test]
	fn should_return_none_for_not_matching_origin() {
		// given
		let origin = Some("http://superstring.ch".into());
		let host = None;

		// when
		let res = get_cors_header(
			origin,
			host,
			&Some(vec![AccessControlAllowOrigin::Value("http://sophon.org".into())]),
		);

		// then
		assert_eq!(res, CorsHeader::Invalid);
	}

	#[test]
	fn should_return_specific_origin_if_we_allow_any() {
		// given
		let origin = Some("http://superstring.ch".into());
		let host = None;

		// when
		let res = get_cors_header(origin, host, &Some(vec![AccessControlAllowOrigin::Any]));

		// then
		assert_eq!(res, CorsHeader::Ok(AccessControlAllowOrigin::Value("http://superstring.ch".into())));
	}

	#[test]
	fn should_return_none_if_origin_is_not_defined() {
		// given
		let origin = None;
		let host = None;

		// when
		let res = get_cors_header(
			origin,
			host,
			&Some(vec![AccessControlAllowOrigin::Null]),
		);

		// then
		assert_eq!(res, CorsHeader::NotRequired);
	}

	#[test]
	fn should_return_null_if_origin_is_null() {
		// given
		let origin = Some("null".into());
		let host = None;

		// when
		let res = get_cors_header(
			origin,
			host,
			&Some(vec![AccessControlAllowOrigin::Null]),
		);

		// then
		assert_eq!(res, CorsHeader::Ok(AccessControlAllowOrigin::Null));
	}

	#[test]
	fn should_return_specific_origin_if_there_is_a_match() {
		// given
		let origin = Some("http://superstring.ch".into());
		let host = None;

		// when
		let res = get_cors_header(
			origin,
			host,
			&Some(vec![AccessControlAllowOrigin::Value("http://sophon.org".into()), AccessControlAllowOrigin::Value("http://superstring.ch".into())]),
		);

		// then
		assert_eq!(res, CorsHeader::Ok(AccessControlAllowOrigin::Value("http://superstring.ch".into())));
	}

	#[test]
	fn should_support_wildcards() {
		// given
		let origin1 = Some("http://superstring.ch".into());
		let origin2 = Some("http://superstring.cht".into());
		let origin3 = Some("chrome-extension://test".into());
		let host = None;
		let allowed = Some(vec![
			// The wildcard TLD must match `superstring.ch` but not `superstring.cht`.
			AccessControlAllowOrigin::Value("http://*.ch".into()),
			AccessControlAllowOrigin::Value("chrome-extension://*".into())
		]);

		// when
		let res1 = get_cors_header(origin1, host, &allowed);
		let res2 = get_cors_header(origin2, host, &allowed);
		let res3 = get_cors_header(origin3, host, &allowed);

		// then
		assert_eq!(res1, CorsHeader::Ok(AccessControlAllowOrigin::Value("http://superstring.ch".into())));
		assert_eq!(res2, CorsHeader::Invalid);
		assert_eq!(res3, CorsHeader::Ok(AccessControlAllowOrigin::Value("chrome-extension://test".into())));
	}
}