1use crate::response::{JcBody, Response};
7use http::{HeaderValue, Method, StatusCode, header};
8use std::time::Duration;
9
/// The set of request origins a [`CorsConfig`] accepts.
#[derive(Debug, Clone)]
pub enum CorsOrigins {
    /// Every origin is accepted.
    Any,
    /// Only the listed origins are accepted.
    List(Vec<String>),
}

impl CorsOrigins {
    /// Returns the variant that accepts every origin.
    pub fn any() -> Self {
        CorsOrigins::Any
    }

    /// Builds an explicit allow-list from anything that yields string-like items.
    ///
    /// Items are stored as given, in iteration order; no normalisation or
    /// de-duplication is performed.
    pub fn list<I, S>(origins: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let allowed: Vec<String> = origins.into_iter().map(|origin| origin.into()).collect();
        CorsOrigins::List(allowed)
    }
}
32
33#[derive(Clone, Debug)]
36pub struct CorsConfig {
37 origins: CorsOrigins,
38 methods: Vec<http::Method>, headers: Vec<String>, expose: Vec<String>,
41 allow_credentials: bool,
42 max_age: Option<Duration>,
43}
44
45impl CorsConfig {
46 pub fn new(origins: CorsOrigins) -> Self {
47 Self {
48 origins,
49 methods: Vec::new(),
50 headers: Vec::new(),
51 expose: Vec::new(),
52 allow_credentials: false,
53 max_age: None,
54 }
55 }
56 pub fn allow_credentials(mut self, yes: bool) -> Self {
57 self.allow_credentials = yes;
58 self
59 }
60 pub fn max_age(mut self, d: Duration) -> Self {
61 self.max_age = Some(d);
62 self
63 }
64 pub fn allow_methods<I: IntoIterator<Item = http::Method>>(mut self, m: I) -> Self {
65 self.methods = m.into_iter().collect();
66 self
67 }
68 pub fn allow_headers<I: IntoIterator<Item = S>, S: Into<String>>(mut self, h: I) -> Self {
69 self.headers = h.into_iter().map(Into::into).collect();
70 self
71 }
72 pub fn expose_headers<I: IntoIterator<Item = S>, S: Into<String>>(mut self, h: I) -> Self {
73 self.expose = h.into_iter().map(Into::into).collect();
74 self
75 }
76
77 pub fn allow_credentials_enabled(&self) -> bool {
79 self.allow_credentials
80 }
81
82 pub fn allows_origin(&self, origin: &str) -> bool {
84 match &self.origins {
85 CorsOrigins::Any => true,
86 CorsOrigins::List(list) => list.iter().any(|o| o == origin),
87 }
88 }
89
90 pub(crate) fn cfg_methods(&self) -> &[http::Method] {
92 &self.methods
93 }
94 pub(crate) fn cfg_headers(&self) -> &[String] {
96 &self.headers
97 }
98 pub(crate) fn cfg_max_age(&self) -> Option<std::time::Duration> {
100 self.max_age
101 }
102 pub(crate) fn credentials(&self) -> bool {
104 self.allow_credentials
105 }
106 pub(crate) fn cfg_expose(&self) -> &[String] {
108 &self.expose
109 }
110
111 pub(crate) fn validate(&self) -> crate::Result<()> {
114 if self.allow_credentials && matches!(self.origins, CorsOrigins::Any) {
115 return Err(crate::Error::internal(
116 "CORS misconfiguration: allow_credentials(true) cannot be combined with CorsOrigins::any() — list explicit origins",
117 ));
118 }
119 Ok(())
120 }
121}
122
123pub(crate) fn is_preflight(parts: &http::request::Parts) -> bool {
127 parts.method == Method::OPTIONS
128 && parts.headers.contains_key(header::ORIGIN)
129 && parts
130 .headers
131 .contains_key(header::ACCESS_CONTROL_REQUEST_METHOD)
132}
133
134pub(crate) fn preflight_response(
139 config: &CorsConfig,
140 origin: &str,
141 request_headers: Option<&str>,
142 allowed_methods: &[Method],
143) -> Response {
144 let mut r = http::Response::new(JcBody::empty());
145 *r.status_mut() = StatusCode::NO_CONTENT;
146 let h = r.headers_mut();
147 if let Ok(v) = HeaderValue::from_str(origin) {
148 h.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, v);
149 h.insert(header::VARY, HeaderValue::from_static("Origin"));
150 }
151 if config.credentials() {
152 h.insert(
153 header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
154 HeaderValue::from_static("true"),
155 );
156 }
157 let methods = if config.cfg_methods().is_empty() {
158 allowed_methods
159 } else {
160 config.cfg_methods()
161 };
162 let methods_joined = methods
163 .iter()
164 .map(Method::as_str)
165 .collect::<Vec<_>>()
166 .join(", ");
167 if let Ok(v) = HeaderValue::from_str(&methods_joined) {
168 h.insert(header::ACCESS_CONTROL_ALLOW_METHODS, v);
169 }
170 let allow_headers = if config.cfg_headers().is_empty() {
171 request_headers.map(str::to_string)
172 } else {
173 Some(config.cfg_headers().join(", "))
174 };
175 if let Some(hdrs) = allow_headers
176 && let Ok(v) = HeaderValue::from_str(&hdrs)
177 {
178 h.insert(header::ACCESS_CONTROL_ALLOW_HEADERS, v);
179 }
180 if let Some(age) = config.cfg_max_age()
181 && let Ok(v) = HeaderValue::from_str(&age.as_secs().to_string())
182 {
183 h.insert(header::ACCESS_CONTROL_MAX_AGE, v);
184 }
185 r
186}
187
188pub(crate) fn apply_cors(res: &mut Response, origin: Option<&HeaderValue>, config: &CorsConfig) {
194 let Some(origin) = origin.and_then(|v| v.to_str().ok()) else {
195 return;
196 };
197 if !config.allows_origin(origin) {
198 return;
199 }
200 let Ok(origin_val) = HeaderValue::from_str(origin) else {
201 return;
202 };
203 let h = res.headers_mut();
204 if !h.contains_key(header::ACCESS_CONTROL_ALLOW_ORIGIN) {
205 h.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin_val);
206 }
207 if config.credentials() && !h.contains_key(header::ACCESS_CONTROL_ALLOW_CREDENTIALS) {
208 h.insert(
209 header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
210 HeaderValue::from_static("true"),
211 );
212 }
213 if !config.cfg_expose().is_empty()
214 && !h.contains_key(header::ACCESS_CONTROL_EXPOSE_HEADERS)
215 && let Ok(v) = HeaderValue::from_str(&config.cfg_expose().join(", "))
216 {
217 h.insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, v);
218 }
219 let has_origin_vary = h.get_all(header::VARY).iter().any(|v| {
222 v.to_str()
223 .map(|s| {
224 s.split(',')
225 .any(|p| p.trim().eq_ignore_ascii_case("origin"))
226 })
227 .unwrap_or(false)
228 });
229 if !has_origin_vary {
230 h.append(header::VARY, HeaderValue::from_static("Origin"));
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// An explicit origin list admits only its members, and the credentials flag set
    /// through the builder is reported back.
    #[test]
    fn config_builder_shapes_origins_and_credentials() {
        let allowed = CorsOrigins::list(["https://app.example"]);
        let ten_minutes = std::time::Duration::from_secs(600);
        let config = CorsConfig::new(allowed)
            .allow_credentials(true)
            .max_age(ten_minutes);

        assert!(config.allows_origin("https://app.example"));
        assert!(!config.allows_origin("https://evil.example"));
        assert!(config.allow_credentials_enabled());
    }

    /// `CorsOrigins::any()` accepts an arbitrary origin.
    #[test]
    fn any_origin_allows_everything() {
        let config = CorsConfig::new(CorsOrigins::any());
        assert!(config.allows_origin("https://whatever.example"));
    }
}