1use std::{fmt, ops};
4use hosts::{Host, Port};
5use matcher::{Matcher, Pattern};
6
/// Scheme part of an origin.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum OriginProtocol {
    /// `http` (also the scheme assumed when none is written).
    Http,
    /// `https`.
    Https,
    /// Any other scheme, e.g. `chrome-extension`; `Origin::parse` stores it lowercased.
    Custom(String),
}
17
18#[derive(Clone, PartialEq, Eq, Debug, Hash)]
20pub struct Origin {
21 protocol: OriginProtocol,
22 host: Host,
23 as_string: String,
24 matcher: Matcher,
25}
26
27impl<T: AsRef<str>> From<T> for Origin {
28 fn from(string: T) -> Self {
29 Origin::parse(string.as_ref())
30 }
31}
32
33impl Origin {
34 fn with_host(protocol: OriginProtocol, host: Host) -> Self {
35 let string = Self::to_string(&protocol, &host);
36 let matcher = Matcher::new(&string);
37
38 Origin {
39 protocol: protocol,
40 host: host,
41 as_string: string,
42 matcher: matcher,
43 }
44 }
45
46 pub fn new<T: Into<Port>>(protocol: OriginProtocol, host: &str, port: T) -> Self {
49 Self::with_host(protocol, Host::new(host, port))
50 }
51
52 pub fn parse(data: &str) -> Self {
55 let mut it = data.split("://");
56 let proto = it.next().expect("split always returns non-empty iterator.");
57 let hostname = it.next();
58
59 let (proto, hostname) = match hostname {
60 None => (None, proto),
61 Some(hostname) => (Some(proto), hostname),
62 };
63
64 let proto = proto.map(str::to_lowercase);
65 let hostname = Host::parse(hostname);
66
67 let protocol = match proto {
68 None => OriginProtocol::Http,
69 Some(ref p) if p == "http" => OriginProtocol::Http,
70 Some(ref p) if p == "https" => OriginProtocol::Https,
71 Some(other) => OriginProtocol::Custom(other),
72 };
73
74 Origin::with_host(protocol, hostname)
75 }
76
77 fn to_string(protocol: &OriginProtocol, host: &Host) -> String {
78 format!(
79 "{}://{}",
80 match *protocol {
81 OriginProtocol::Http => "http",
82 OriginProtocol::Https => "https",
83 OriginProtocol::Custom(ref protocol) => protocol,
84 },
85 &**host,
86 )
87 }
88}
89
90impl Pattern for Origin {
91 fn matches<T: AsRef<str>>(&self, other: T) -> bool {
92 self.matcher.matches(other)
93 }
94}
95
96impl ops::Deref for Origin {
97 type Target = str;
98 fn deref(&self) -> &Self::Target {
99 &self.as_string
100 }
101}
102
103#[derive(Debug, Clone, PartialEq, Eq)]
105pub enum AccessControlAllowOrigin {
106 Value(Origin),
108 Null,
110 Any,
112}
113
114impl fmt::Display for AccessControlAllowOrigin {
115 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
116 write!(f, "{}", match *self {
117 AccessControlAllowOrigin::Any => "*",
118 AccessControlAllowOrigin::Null => "null",
119 AccessControlAllowOrigin::Value(ref val) => val,
120 })
121 }
122}
123
124impl<T: Into<String>> From<T> for AccessControlAllowOrigin {
125 fn from(s: T) -> AccessControlAllowOrigin {
126 match s.into().as_str() {
127 "all" | "*" | "any" => AccessControlAllowOrigin::Any,
128 "null" => AccessControlAllowOrigin::Null,
129 origin => AccessControlAllowOrigin::Value(origin.into()),
130 }
131 }
132}
133
134#[derive(Debug, Clone, PartialEq, Eq)]
136pub enum CorsHeader<T = AccessControlAllowOrigin> {
137 NotRequired,
139 Invalid,
141 Ok(T),
143}
144
145impl<T> CorsHeader<T> {
146 pub fn map<F, O>(self, f: F) -> CorsHeader<O> where
148 F: FnOnce(T) -> O,
149 {
150 use self::CorsHeader::*;
151
152 match self {
153 NotRequired => NotRequired,
154 Invalid => Invalid,
155 Ok(val) => Ok(f(val)),
156 }
157 }
158}
159
160impl<T> Into<Option<T>> for CorsHeader<T> {
161 fn into(self) -> Option<T> {
162 use self::CorsHeader::*;
163
164 match self {
165 NotRequired | Invalid => None,
166 Ok(header) => Some(header),
167 }
168 }
169}
170
171pub fn get_cors_header(origin: Option<&str>, host: Option<&str>, allowed: &Option<Vec<AccessControlAllowOrigin>>) -> CorsHeader {
173 match origin {
174 None => CorsHeader::NotRequired,
175 Some(ref origin) => {
176 if let Some(host) = host {
177 if origin.ends_with(host) {
179 let origin = Origin::parse(origin);
181 if &*origin.host == host {
182 return CorsHeader::NotRequired;
183 }
184 }
185 }
186
187 match allowed.as_ref() {
188 None => CorsHeader::Ok(AccessControlAllowOrigin::Value(Origin::parse(origin))),
189 Some(ref allowed) if *origin == "null" => {
190 allowed.iter().find(|cors| **cors == AccessControlAllowOrigin::Null).cloned()
191 .map(CorsHeader::Ok)
192 .unwrap_or(CorsHeader::Invalid)
193 },
194 Some(ref allowed) => {
195 allowed.iter().find(|cors| {
196 match **cors {
197 AccessControlAllowOrigin::Any => true,
198 AccessControlAllowOrigin::Value(ref val) if val.matches(origin) => true,
199 _ => false
200 }
201 })
202 .map(|_| AccessControlAllowOrigin::Value(Origin::parse(origin)))
203 .map(CorsHeader::Ok).unwrap_or(CorsHeader::Invalid)
204 },
205 }
206 },
207 }
208}
209
210
#[cfg(test)]
mod tests {
    use hosts::Host;
    use super::{get_cors_header, CorsHeader, AccessControlAllowOrigin, Origin, OriginProtocol};

    #[test]
    fn should_parse_origin() {
        use self::OriginProtocol::*;

        assert_eq!(Origin::parse("http://superstring.ch"), Origin::new(Http, "superstring.ch", None));
        // The scheme must be `https` for the parsed origin to equal an `Https` origin.
        assert_eq!(Origin::parse("https://superstring.ch:8443"), Origin::new(Https, "superstring.ch", Some(8443)));
        assert_eq!(Origin::parse("chrome-extension://124.0.0.1"), Origin::new(Custom("chrome-extension".into()), "124.0.0.1", None));
        assert_eq!(Origin::parse("superstring.ch/somepath"), Origin::new(Http, "superstring.ch", None));
        assert_eq!(Origin::parse("127.0.0.1:8545/somepath"), Origin::new(Http, "127.0.0.1", Some(8545)));
    }

    #[test]
    fn should_not_allow_partially_matching_origin() {
        let origin1 = Origin::parse("http://subdomain.somedomain.io");
        let origin2 = Origin::parse("http://somedomain.io:8080");
        let host = Host::parse("http://somedomain.io");

        let origin1 = Some(&*origin1);
        let origin2 = Some(&*origin2);
        let host = Some(&*host);

        let res1 = get_cors_header(origin1, host, &Some(vec![]));
        let res2 = get_cors_header(origin2, host, &Some(vec![]));

        assert_eq!(res1, CorsHeader::Invalid);
        assert_eq!(res2, CorsHeader::Invalid);
    }

    #[test]
    fn should_allow_origins_that_matches_hosts() {
        let origin = Origin::parse("http://127.0.0.1:8080");
        let host = Host::parse("http://127.0.0.1:8080");

        let origin = Some(&*origin);
        let host = Some(&*host);

        let res = get_cors_header(origin, host, &None);

        assert_eq!(res, CorsHeader::NotRequired);
    }

    #[test]
    fn should_return_none_when_there_are_no_cors_domains_and_no_origin() {
        let origin = None;
        let host = None;

        let res = get_cors_header(origin, host, &None);

        assert_eq!(res, CorsHeader::NotRequired);
    }

    #[test]
    fn should_return_domain_when_all_are_allowed() {
        let origin = Some("superstring.ch");
        let host = None;

        let res = get_cors_header(origin, host, &None);

        assert_eq!(res, CorsHeader::Ok("superstring.ch".into()));
    }

    #[test]
    fn should_return_none_for_empty_origin() {
        let origin = None;
        let host = None;

        let res = get_cors_header(
            origin,
            host,
            &Some(vec![AccessControlAllowOrigin::Value("http://sophon.org".into())]),
        );

        assert_eq!(res, CorsHeader::NotRequired);
    }

    #[test]
    fn should_return_none_for_empty_list() {
        let origin = None;
        let host = None;

        let res = get_cors_header(origin, host, &Some(Vec::new()));

        assert_eq!(res, CorsHeader::NotRequired);
    }

    #[test]
    fn should_return_none_for_not_matching_origin() {
        let origin = Some("http://superstring.ch".into());
        let host = None;

        let res = get_cors_header(
            origin,
            host,
            &Some(vec![AccessControlAllowOrigin::Value("http://sophon.org".into())]),
        );

        assert_eq!(res, CorsHeader::Invalid);
    }

    #[test]
    fn should_return_specific_origin_if_we_allow_any() {
        let origin = Some("http://superstring.ch".into());
        let host = None;

        let res = get_cors_header(origin, host, &Some(vec![AccessControlAllowOrigin::Any]));

        assert_eq!(res, CorsHeader::Ok(AccessControlAllowOrigin::Value("http://superstring.ch".into())));
    }

    #[test]
    fn should_return_none_if_origin_is_not_defined() {
        let origin = None;
        let host = None;

        let res = get_cors_header(
            origin,
            host,
            &Some(vec![AccessControlAllowOrigin::Null]),
        );

        assert_eq!(res, CorsHeader::NotRequired);
    }

    #[test]
    fn should_return_null_if_origin_is_null() {
        let origin = Some("null".into());
        let host = None;

        let res = get_cors_header(
            origin,
            host,
            &Some(vec![AccessControlAllowOrigin::Null]),
        );

        assert_eq!(res, CorsHeader::Ok(AccessControlAllowOrigin::Null));
    }

    #[test]
    fn should_return_specific_origin_if_there_is_a_match() {
        let origin = Some("http://superstring.ch".into());
        let host = None;

        let res = get_cors_header(
            origin,
            host,
            &Some(vec![AccessControlAllowOrigin::Value("http://sophon.org".into()), AccessControlAllowOrigin::Value("http://superstring.ch".into())]),
        );

        assert_eq!(res, CorsHeader::Ok(AccessControlAllowOrigin::Value("http://superstring.ch".into())));
    }

    #[test]
    fn should_support_wildcards() {
        // `.ch` matches the wildcard TLD below; `.cht` must not.
        let origin1 = Some("http://superstring.ch".into());
        let origin2 = Some("http://superstring.cht".into());
        let origin3 = Some("chrome-extension://test".into());
        let host = None;
        let allowed = Some(vec![
            AccessControlAllowOrigin::Value("http://*.ch".into()),
            AccessControlAllowOrigin::Value("chrome-extension://*".into())
        ]);

        let res1 = get_cors_header(origin1, host, &allowed);
        let res2 = get_cors_header(origin2, host, &allowed);
        let res3 = get_cors_header(origin3, host, &allowed);

        assert_eq!(res1, CorsHeader::Ok(AccessControlAllowOrigin::Value("http://superstring.ch".into())));
        assert_eq!(res2, CorsHeader::Invalid);
        assert_eq!(res3, CorsHeader::Ok(AccessControlAllowOrigin::Value("chrome-extension://test".into())));
    }
}