Skip to main content

nimble_http/
cors.rs

//! CORS (Cross-Origin Resource Sharing) policy configuration and response-header helpers.
2
3use hyper::{Body, Response, header};
4use std::collections::HashSet;
5use std::time::Duration;
6
/// CORS policy: which origins, methods and headers are advertised to browsers.
///
/// `Default` yields a policy with no origin allow-list, empty method/header
/// sets, credentials disabled and no max-age.
#[derive(Clone, Debug, Default)]
pub struct Cors {
	/// Origins that may access the resource. `None` means no allow-list is
	/// configured, in which case any request origin is accepted.
	pub allow_origins: Option<HashSet<String>>,
	/// Methods advertised via `Access-Control-Allow-Methods` on preflight
	/// responses; the header is omitted when the set is empty.
	pub allow_methods: HashSet<String>,
	/// Request headers advertised via `Access-Control-Allow-Headers` on
	/// preflight responses; the header is omitted when the set is empty.
	pub allow_headers: HashSet<String>,
	/// Response headers advertised via `Access-Control-Expose-Headers` on
	/// regular responses; the header is omitted when the set is empty.
	pub expose_headers: HashSet<String>,
	/// When `true`, `Access-Control-Allow-Credentials: true` is sent.
	pub allow_credentials: bool,
	/// Preflight cache lifetime, sent as whole seconds in
	/// `Access-Control-Max-Age`; the header is omitted when `None`.
	pub max_age: Option<Duration>,
}
16
17impl Cors {
18	pub fn new() -> Self {
19		Self::default()
20	}
21
22	pub fn allow_origins<T, I>(mut self, origins: I) -> Self
23	where
24		T: AsRef<str> + Into<String>,
25		I: IntoIterator<Item = T>,
26	{
27		self.allow_origins = Some(origins.into_iter().map(|s| s.into()).collect());
28		self
29	}
30
31	pub fn allow_methods<T, I>(mut self, methods: I) -> Self
32	where
33		T: AsRef<str> + Into<String>,
34		I: IntoIterator<Item = T>,
35	{
36		self.allow_methods = methods.into_iter().map(|s| s.into()).collect();
37		self
38	}
39
40	pub fn allow_headers<T, I>(mut self, headers: I) -> Self
41	where
42		T: AsRef<str> + Into<String>,
43		I: IntoIterator<Item = T>,
44	{
45		self.allow_headers = headers.into_iter().map(|s| s.into()).collect();
46		self
47	}
48
49	pub fn allow_credentials(mut self, allow: bool) -> Self {
50		self.allow_credentials = allow;
51		self
52	}
53
54	pub fn max_age<T: Into<Duration>>(mut self, duration: T) -> Self {
55		self.max_age = Some(duration.into());
56		self
57	}
58
59	/// 处理 OPTIONS 预检请求
60	pub fn handle_preflight(&self, origin: Option<&str>) -> Response<Body> {
61		// 改这里
62		let mut response = Response::new(Body::empty());
63		let headers = response.headers_mut();
64
65		// 设置 Allow-Origin
66		match &self.allow_origins {
67			Some(origins) => {
68				if let Some(origin) = origin {
69					if origins.contains(origin) {
70						headers
71							.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.parse().unwrap());
72					}
73				}
74			}
75			None => {
76				if let Some(origin) = origin {
77					headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.parse().unwrap());
78				} else {
79					headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*".parse().unwrap());
80				}
81			}
82		}
83
84		// 设置 Allow-Methods
85		if !self.allow_methods.is_empty() {
86			let methods = self
87				.allow_methods
88				.iter()
89				.map(|s| s.as_str())
90				.collect::<Vec<_>>()
91				.join(", ");
92			headers.insert(
93				header::ACCESS_CONTROL_ALLOW_METHODS,
94				methods.parse().unwrap(),
95			);
96		}
97
98		// 设置 Allow-Headers
99		if !self.allow_headers.is_empty() {
100			let headers_str = self
101				.allow_headers
102				.iter()
103				.map(|s| s.as_str())
104				.collect::<Vec<_>>()
105				.join(", ");
106			headers.insert(
107				header::ACCESS_CONTROL_ALLOW_HEADERS,
108				headers_str.parse().unwrap(),
109			);
110		}
111
112		// 设置 Max-Age
113		if let Some(max_age) = self.max_age {
114			headers.insert(
115				header::ACCESS_CONTROL_MAX_AGE,
116				max_age.as_secs().to_string().parse().unwrap(),
117			);
118		}
119
120		// 设置 Credentials
121		if self.allow_credentials {
122			headers.insert(
123				header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
124				"true".parse().unwrap(),
125			);
126		}
127
128		response
129	}
130
131	/// 给正常响应添加 CORS 头
132	pub fn apply_headers(&self, response: &mut Response<Body>, origin: Option<&str>) {
133		// 改这里
134		let headers = response.headers_mut();
135
136		// 设置 Allow-Origin
137		match &self.allow_origins {
138			Some(origins) => {
139				if let Some(origin) = origin {
140					if origins.contains(origin) {
141						headers
142							.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.parse().unwrap());
143					}
144				}
145			}
146			None => {
147				if let Some(origin) = origin {
148					headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.parse().unwrap());
149				} else {
150					headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*".parse().unwrap());
151				}
152			}
153		}
154
155		// 设置 Expose-Headers
156		if !self.expose_headers.is_empty() {
157			let expose = self
158				.expose_headers
159				.iter()
160				.map(|s| s.as_str())
161				.collect::<Vec<_>>()
162				.join(", ");
163			headers.insert(
164				header::ACCESS_CONTROL_EXPOSE_HEADERS,
165				expose.parse().unwrap(),
166			);
167		}
168
169		// 设置 Credentials
170		if self.allow_credentials {
171			headers.insert(
172				header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
173				"true".parse().unwrap(),
174			);
175		}
176	}
177}