1use hyper::{Body, Response, header};
4use std::collections::HashSet;
5use std::time::Duration;
6
/// Cross-Origin Resource Sharing (CORS) configuration.
///
/// Build one with [`Cors::new`] and the builder methods, or fill in the public
/// fields directly. The default value has no origin allow-list (every origin is
/// accepted), no methods, headers or exposed headers, credentials disabled and
/// no max age.
#[derive(Clone, Debug, Default)]
pub struct Cors {
    /// Origins allowed to make cross-origin requests. `None` accepts every
    /// origin; `Some(set)` accepts only origins contained in `set`
    /// (exact string comparison).
    pub allow_origins: Option<HashSet<String>>,
    /// Methods listed in `Access-Control-Allow-Methods` on preflight responses.
    /// The header is omitted when the set is empty.
    pub allow_methods: HashSet<String>,
    /// Request header names listed in `Access-Control-Allow-Headers` on
    /// preflight responses. The header is omitted when the set is empty.
    pub allow_headers: HashSet<String>,
    /// Response header names listed in `Access-Control-Expose-Headers` on
    /// non-preflight responses. The header is omitted when the set is empty.
    pub expose_headers: HashSet<String>,
    /// When `true`, responses carry `Access-Control-Allow-Credentials: true`.
    pub allow_credentials: bool,
    /// Value sent as `Access-Control-Max-Age` on preflight responses, in whole
    /// seconds (any fractional part is dropped). Omitted when `None`.
    pub max_age: Option<Duration>,
}
16
17impl Cors {
18 pub fn new() -> Self {
19 Self::default()
20 }
21
22 pub fn allow_origins<T, I>(mut self, origins: I) -> Self
23 where
24 T: AsRef<str> + Into<String>,
25 I: IntoIterator<Item = T>,
26 {
27 self.allow_origins = Some(origins.into_iter().map(|s| s.into()).collect());
28 self
29 }
30
31 pub fn allow_methods<T, I>(mut self, methods: I) -> Self
32 where
33 T: AsRef<str> + Into<String>,
34 I: IntoIterator<Item = T>,
35 {
36 self.allow_methods = methods.into_iter().map(|s| s.into()).collect();
37 self
38 }
39
40 pub fn allow_headers<T, I>(mut self, headers: I) -> Self
41 where
42 T: AsRef<str> + Into<String>,
43 I: IntoIterator<Item = T>,
44 {
45 self.allow_headers = headers.into_iter().map(|s| s.into()).collect();
46 self
47 }
48
49 pub fn allow_credentials(mut self, allow: bool) -> Self {
50 self.allow_credentials = allow;
51 self
52 }
53
54 pub fn max_age<T: Into<Duration>>(mut self, duration: T) -> Self {
55 self.max_age = Some(duration.into());
56 self
57 }
58
59 pub fn handle_preflight(&self, origin: Option<&str>) -> Response<Body> {
61 let mut response = Response::new(Body::empty());
63 let headers = response.headers_mut();
64
65 match &self.allow_origins {
67 Some(origins) => {
68 if let Some(origin) = origin {
69 if origins.contains(origin) {
70 headers
71 .insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.parse().unwrap());
72 }
73 }
74 }
75 None => {
76 if let Some(origin) = origin {
77 headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.parse().unwrap());
78 } else {
79 headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*".parse().unwrap());
80 }
81 }
82 }
83
84 if !self.allow_methods.is_empty() {
86 let methods = self
87 .allow_methods
88 .iter()
89 .map(|s| s.as_str())
90 .collect::<Vec<_>>()
91 .join(", ");
92 headers.insert(
93 header::ACCESS_CONTROL_ALLOW_METHODS,
94 methods.parse().unwrap(),
95 );
96 }
97
98 if !self.allow_headers.is_empty() {
100 let headers_str = self
101 .allow_headers
102 .iter()
103 .map(|s| s.as_str())
104 .collect::<Vec<_>>()
105 .join(", ");
106 headers.insert(
107 header::ACCESS_CONTROL_ALLOW_HEADERS,
108 headers_str.parse().unwrap(),
109 );
110 }
111
112 if let Some(max_age) = self.max_age {
114 headers.insert(
115 header::ACCESS_CONTROL_MAX_AGE,
116 max_age.as_secs().to_string().parse().unwrap(),
117 );
118 }
119
120 if self.allow_credentials {
122 headers.insert(
123 header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
124 "true".parse().unwrap(),
125 );
126 }
127
128 response
129 }
130
131 pub fn apply_headers(&self, response: &mut Response<Body>, origin: Option<&str>) {
133 let headers = response.headers_mut();
135
136 match &self.allow_origins {
138 Some(origins) => {
139 if let Some(origin) = origin {
140 if origins.contains(origin) {
141 headers
142 .insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.parse().unwrap());
143 }
144 }
145 }
146 None => {
147 if let Some(origin) = origin {
148 headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.parse().unwrap());
149 } else {
150 headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*".parse().unwrap());
151 }
152 }
153 }
154
155 if !self.expose_headers.is_empty() {
157 let expose = self
158 .expose_headers
159 .iter()
160 .map(|s| s.as_str())
161 .collect::<Vec<_>>()
162 .join(", ");
163 headers.insert(
164 header::ACCESS_CONTROL_EXPOSE_HEADERS,
165 expose.parse().unwrap(),
166 );
167 }
168
169 if self.allow_credentials {
171 headers.insert(
172 header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
173 "true".parse().unwrap(),
174 );
175 }
176 }
177}