1use hyper::{header, Response, Body};
4use std::collections::HashSet;
5use std::time::Duration;
6
/// A Cross-Origin Resource Sharing (CORS) policy.
///
/// Build one with [`Cors::new`] and the builder methods. Then use it to answer
/// preflight requests ([`Cors::handle_preflight`]) and to add CORS headers to
/// ordinary responses ([`Cors::apply_headers`]).
#[derive(Clone, Debug, Default)]
pub struct Cors {
    /// Exact origin strings allowed to make cross-origin requests.
    /// `None` means any origin is accepted.
    pub allow_origins: Option<HashSet<String>>,
    /// Methods advertised in `Access-Control-Allow-Methods` on preflight responses.
    pub allow_methods: HashSet<String>,
    /// Request headers advertised in `Access-Control-Allow-Headers` on preflight responses.
    pub allow_headers: HashSet<String>,
    /// Response headers listed in `Access-Control-Expose-Headers` on actual responses.
    pub expose_headers: HashSet<String>,
    /// When `true`, `Access-Control-Allow-Credentials: true` is sent.
    pub allow_credentials: bool,
    /// Preflight cache lifetime, sent as `Access-Control-Max-Age` in whole seconds.
    pub max_age: Option<Duration>,
}
16
17impl Cors {
18 pub fn new() -> Self {
19 Self::default()
20 }
21
22 pub fn allow_origins<T, I>(mut self, origins: I) -> Self
23 where
24 T: AsRef<str> + Into<String>,
25 I: IntoIterator<Item = T>,
26 {
27 self.allow_origins = Some(origins.into_iter().map(|s| s.into()).collect());
28 self
29 }
30
31 pub fn allow_methods<T, I>(mut self, methods: I) -> Self
32 where
33 T: AsRef<str> + Into<String>,
34 I: IntoIterator<Item = T>,
35 {
36 self.allow_methods = methods.into_iter().map(|s| s.into()).collect();
37 self
38 }
39
40 pub fn allow_headers<T, I>(mut self, headers: I) -> Self
41 where
42 T: AsRef<str> + Into<String>,
43 I: IntoIterator<Item = T>,
44 {
45 self.allow_headers = headers.into_iter().map(|s| s.into()).collect();
46 self
47 }
48
49 pub fn allow_credentials(mut self, allow: bool) -> Self {
50 self.allow_credentials = allow;
51 self
52 }
53
54 pub fn max_age<T: Into<Duration>>(mut self, duration: T) -> Self {
55 self.max_age = Some(duration.into());
56 self
57 }
58
59 pub fn handle_preflight(&self, origin: Option<&str>) -> Response<Body> { let mut response = Response::new(Body::empty());
62 let headers = response.headers_mut();
63
64 match &self.allow_origins {
66 Some(origins) => {
67 if let Some(origin) = origin {
68 if origins.contains(origin) {
69 headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.parse().unwrap());
70 }
71 }
72 }
73 None => {
74 if let Some(origin) = origin {
75 headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.parse().unwrap());
76 } else {
77 headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*".parse().unwrap());
78 }
79 }
80 }
81
82 if !self.allow_methods.is_empty() {
84 let methods = self.allow_methods.iter()
85 .map(|s| s.as_str())
86 .collect::<Vec<_>>()
87 .join(", ");
88 headers.insert(header::ACCESS_CONTROL_ALLOW_METHODS, methods.parse().unwrap());
89 }
90
91 if !self.allow_headers.is_empty() {
93 let headers_str = self.allow_headers.iter()
94 .map(|s| s.as_str())
95 .collect::<Vec<_>>()
96 .join(", ");
97 headers.insert(header::ACCESS_CONTROL_ALLOW_HEADERS, headers_str.parse().unwrap());
98 }
99
100 if let Some(max_age) = self.max_age {
102 headers.insert(header::ACCESS_CONTROL_MAX_AGE, max_age.as_secs().to_string().parse().unwrap());
103 }
104
105 if self.allow_credentials {
107 headers.insert(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true".parse().unwrap());
108 }
109
110 response
111 }
112
113 pub fn apply_headers(&self, response: &mut Response<Body>, origin: Option<&str>) { let headers = response.headers_mut();
116
117 match &self.allow_origins {
119 Some(origins) => {
120 if let Some(origin) = origin {
121 if origins.contains(origin) {
122 headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.parse().unwrap());
123 }
124 }
125 }
126 None => {
127 if let Some(origin) = origin {
128 headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.parse().unwrap());
129 } else {
130 headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*".parse().unwrap());
131 }
132 }
133 }
134
135 if !self.expose_headers.is_empty() {
137 let expose = self.expose_headers.iter()
138 .map(|s| s.as_str())
139 .collect::<Vec<_>>()
140 .join(", ");
141 headers.insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, expose.parse().unwrap());
142 }
143
144 if self.allow_credentials {
146 headers.insert(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true".parse().unwrap());
147 }
148 }
149}