salvo_cors/lib.rs
//! This library adds [CORS] protection to the Salvo web framework.
2//!
3//! [CORS]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
4//!
5//! # Docs
6//! Find the docs here: <https://salvo.rs/book/features/cors.html>
7#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
8#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
9#![cfg_attr(docsrs, feature(doc_cfg))]
10
11use bytes::{BufMut, BytesMut};
12use salvo_core::http::header::{self, HeaderMap, HeaderName, HeaderValue};
13use salvo_core::http::{Method, Request, Response, StatusCode};
14use salvo_core::{Depot, FlowCtrl, Handler, async_trait};
15
16mod allow_credentials;
17mod allow_headers;
18mod allow_methods;
19mod allow_origin;
20mod allow_private_network;
21mod expose_headers;
22mod max_age;
23mod vary;
24
25pub use self::allow_credentials::AllowCredentials;
26pub use self::allow_headers::AllowHeaders;
27pub use self::allow_methods::AllowMethods;
28pub use self::allow_origin::AllowOrigin;
29pub use self::allow_private_network::AllowPrivateNetwork;
30pub use self::expose_headers::ExposeHeaders;
31pub use self::max_age::MaxAge;
32pub use self::vary::Vary;
33
// The literal `*` header value. It is not referenced anywhere in this file;
// presumably the header-specific submodules emit it when configured with [`Any`]
// — confirm against `allow_headers`/`allow_methods`/`allow_origin`/`expose_headers`.
static WILDCARD: HeaderValue = HeaderValue::from_static("*");

/// Represents a wildcard value (`*`) used with some CORS headers such as
/// [`Cors::allow_methods`].
///
/// Pass it to a builder method (for example `Cors::new().allow_origin(Any)`)
/// to allow everything for that header.
#[derive(Debug, Clone, Copy)]
#[must_use]
pub struct Any;
41
42fn separated_by_commas<I>(mut iter: I) -> Option<HeaderValue>
43where
44 I: Iterator<Item = HeaderValue>,
45{
46 match iter.next() {
47 Some(fst) => {
48 let mut result = BytesMut::from(fst.as_bytes());
49 for val in iter {
50 result.reserve(val.len() + 1);
51 result.put_u8(b',');
52 result.extend_from_slice(val.as_bytes());
53 }
54
55 HeaderValue::from_maybe_shared(result.freeze()).ok()
56 }
57 None => None,
58 }
59}
60
/// [`Cors`] middleware which adds headers for [CORS][mdn].
///
/// Build it with [`Cors::new`] (or one of the preset constructors) and the
/// builder methods, then turn it into a handler with [`Cors::into_handler`].
///
/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
#[derive(Clone, Debug)]
pub struct Cors {
    /// Produces the `Access-Control-Allow-Credentials` header (every response).
    allow_credentials: AllowCredentials,
    /// Produces the `Access-Control-Allow-Headers` header (preflight only).
    allow_headers: AllowHeaders,
    /// Produces the `Access-Control-Allow-Methods` header (preflight only).
    allow_methods: AllowMethods,
    /// Produces the `Access-Control-Allow-Origin` header (every response).
    allow_origin: AllowOrigin,
    /// Produces the `Access-Control-Allow-Private-Network` header (every response).
    allow_private_network: AllowPrivateNetwork,
    /// Produces the `Access-Control-Expose-Headers` header (non-preflight only).
    expose_headers: ExposeHeaders,
    /// Produces the `Access-Control-Max-Age` header (preflight only).
    max_age: MaxAge,
    /// Values written to the `Vary` header (every response).
    vary: Vary,
}
75impl Default for Cors {
76 #[inline]
77 fn default() -> Self {
78 Self::new()
79 }
80}
81
82impl Cors {
83 /// Create new `Cors`.
84 #[inline]
85 #[must_use]
86 pub fn new() -> Self {
87 Self {
88 allow_credentials: Default::default(),
89 allow_headers: Default::default(),
90 allow_methods: Default::default(),
91 allow_origin: Default::default(),
92 allow_private_network: Default::default(),
93 expose_headers: Default::default(),
94 max_age: Default::default(),
95 vary: Default::default(),
96 }
97 }
98
99 /// A permissive configuration:
100 ///
101 /// - All request headers allowed.
102 /// - All methods allowed.
103 /// - All origins allowed.
104 /// - All headers exposed.
105 ///
106 /// # Security Warning
107 ///
108 /// **This configuration allows any website to make requests to your API.**
109 /// Only use this for:
110 /// - Public APIs that don't require authentication
111 /// - Development/testing environments
112 ///
113 /// For production APIs that require authentication, configure CORS explicitly
114 /// with specific allowed origins.
115 #[must_use]
116 pub fn permissive() -> Self {
117 Self::new()
118 .allow_headers(Any)
119 .allow_methods(Any)
120 .allow_origin(Any)
121 .expose_headers(Any)
122 }
123
124 /// A very permissive configuration:
125 ///
126 /// - **Credentials allowed.**
127 /// - The method received in `Access-Control-Request-Method` is sent back as an allowed method.
128 /// - The origin of the preflight request is sent back as an allowed origin.
129 /// - The header names received in `Access-Control-Request-Headers` are sent back as allowed
130 /// headers.
131 /// - No headers are currently exposed, but this may change in the future.
132 ///
133 /// # Security Warning
134 ///
135 /// **⚠️ DANGER: This configuration essentially disables CORS protection!**
136 ///
137 /// By enabling credentials AND mirroring the request origin, you are allowing
138 /// ANY website to:
139 /// - Make authenticated requests to your API
140 /// - Read response data including sensitive information
141 /// - Perform actions on behalf of logged-in users (CSRF attacks)
142 ///
143 /// **This should NEVER be used in production with authentication.**
144 ///
145 /// Only use this for:
146 /// - Local development where security is not a concern
147 /// - Internal tools on trusted networks
148 ///
149 /// For production, always configure explicit allowed origins:
150 /// ```ignore
151 /// Cors::new()
152 /// .allow_origin("https://your-frontend.com")
153 /// .allow_credentials(true)
154 /// ```
155 #[must_use]
156 pub fn very_permissive() -> Self {
157 tracing::warn!(
158 "Using Cors::very_permissive() - this disables CORS security and should not be used in production!"
159 );
160 Self::new()
161 .allow_credentials(true)
162 .allow_headers(AllowHeaders::mirror_request())
163 .allow_methods(AllowMethods::mirror_request())
164 .allow_origin(AllowOrigin::mirror_request())
165 }
166
167 /// Sets whether to add the `Access-Control-Allow-Credentials` header.
168 #[inline]
169 #[must_use]
170 pub fn allow_credentials(mut self, allow_credentials: impl Into<AllowCredentials>) -> Self {
171 self.allow_credentials = allow_credentials.into();
172 self
173 }
174
175 /// Adds multiple headers to the list of allowed request headers.
176 ///
177 /// **Note**: These should match the values the browser sends via
178 /// `Access-Control-Request-Headers`, e.g.`content-type`.
179 ///
180 /// # Panics
181 ///
182 /// Panics if any of the headers are not a valid `http::header::HeaderName`.
183 #[inline]
184 #[must_use]
185 pub fn allow_headers(mut self, headers: impl Into<AllowHeaders>) -> Self {
186 self.allow_headers = headers.into();
187 self
188 }
189
190 /// Sets the `Access-Control-Max-Age` header.
191 ///
192 /// # Example
193 ///
194 /// ```
195 /// use std::time::Duration;
196 ///
197 /// use salvo_core::prelude::*;
198 /// use salvo_cors::Cors;
199 ///
200 /// let cors = Cors::new().max_age(30); // 30 seconds
201 /// let cors = Cors::new().max_age(Duration::from_secs(30)); // or a Duration
202 /// ```
203 #[inline]
204 #[must_use]
205 pub fn max_age(mut self, max_age: impl Into<MaxAge>) -> Self {
206 self.max_age = max_age.into();
207 self
208 }
209
210 /// Adds multiple methods to the existing list of allowed request methods.
211 ///
212 /// # Panics
213 ///
214 /// Panics if the provided argument is not a valid `http::Method`.
215 #[inline]
216 #[must_use]
217 pub fn allow_methods<I>(mut self, methods: I) -> Self
218 where
219 I: Into<AllowMethods>,
220 {
221 self.allow_methods = methods.into();
222 self
223 }
224
225 /// Set the value of the [`Access-Control-Allow-Origin`][mdn] header.
226 /// ```
227 /// use salvo_core::http::HeaderValue;
228 /// use salvo_cors::Cors;
229 ///
230 /// let cors = Cors::new().allow_origin("http://example.com".parse::<HeaderValue>().unwrap());
231 /// ```
232 ///
233 /// Multiple origins can be allowed with
234 ///
235 /// ```
236 /// use salvo_cors::Cors;
237 ///
238 /// let origins = ["http://example.com", "http://api.example.com"];
239 ///
240 /// let cors = Cors::new().allow_origin(origins);
241 /// ```
242 ///
243 /// All origins can be allowed with
244 ///
245 /// ```
246 /// use salvo_cors::{Any, Cors};
247 ///
248 /// let cors = Cors::new().allow_origin(Any);
249 /// ```
250 ///
251 /// You can also use a closure
252 ///
253 /// ```
254 /// use salvo_core::http::HeaderValue;
255 /// use salvo_core::{Depot, Request};
256 /// use salvo_cors::{AllowOrigin, Cors};
257 ///
258 /// let cors = Cors::new().allow_origin(AllowOrigin::dynamic(
259 /// |origin: Option<&HeaderValue>, _req: &Request, _depot: &Depot| {
260 /// if origin?.as_bytes().ends_with(b".rust-lang.org") {
261 /// origin.cloned()
262 /// } else {
263 /// None
264 /// }
265 /// },
266 /// ));
267 /// ```
268 ///
269 /// You can also use an async closure, make sure all the values are owned
270 /// before passing into the future:
271 ///
272 /// ```
273 /// # #[derive(Clone)]
274 /// # struct Client;
275 /// # fn get_api_client() -> Client {
276 /// # Client
277 /// # }
278 /// # impl Client {
279 /// # async fn fetch_allowed_origins(&self) -> Vec<HeaderValue> {
280 /// # vec![HeaderValue::from_static("http://example.com")]
281 /// # }
282 /// # async fn fetch_allowed_origins_for_path(&self, _path: String) -> Vec<HeaderValue> {
283 /// # vec![HeaderValue::from_static("http://example.com")]
284 /// # }
285 /// # }
286 /// use salvo_core::http::header::HeaderValue;
287 /// use salvo_core::{Depot, Request};
288 /// use salvo_cors::{AllowOrigin, Cors};
289 ///
290 /// let cors = Cors::new().allow_origin(AllowOrigin::dynamic_async(
291 /// |origin: Option<&HeaderValue>, _req: &Request, _depot: &Depot| {
292 /// let origin = origin.cloned();
293 /// async move {
294 /// let client = get_api_client();
295 /// // fetch list of origins that are allowed
296 /// let origins = client.fetch_allowed_origins().await;
297 /// if origins.contains(origin.as_ref()?) {
298 /// origin
299 /// } else {
300 /// None
301 /// }
302 /// }
303 /// },
304 /// ));
305 /// ```
306 ///
307 /// **Note** that multiple calls to this method will override any previous
308 /// calls.
309 ///
310 /// **Note** origin must contain http or https protocol name.
311 ///
312 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
313 #[inline]
314 #[must_use]
315 pub fn allow_origin(mut self, origin: impl Into<AllowOrigin>) -> Self {
316 self.allow_origin = origin.into();
317 self
318 }
319
320 /// Set the value of the [`Access-Control-Expose-Headers`][mdn] header.
321 ///
322 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers
323 #[inline]
324 #[must_use]
325 pub fn expose_headers(mut self, headers: impl Into<ExposeHeaders>) -> Self {
326 self.expose_headers = headers.into();
327 self
328 }
329
330 /// Set the value of the [`Access-Control-Allow-Private-Network`][wicg] header.
331 ///
332 /// ```
333 /// use salvo_cors::Cors;
334 ///
335 /// let cors = Cors::new().allow_private_network(true);
336 /// ```
337 ///
338 /// [wicg]: https://wicg.github.io/private-network-access/
339 #[must_use]
340 pub fn allow_private_network<T>(mut self, allow_private_network: T) -> Self
341 where
342 T: Into<AllowPrivateNetwork>,
343 {
344 self.allow_private_network = allow_private_network.into();
345 self
346 }
347
348 /// Set the value(s) of the [`Vary`][mdn] header.
349 ///
350 /// In contrast to the other headers, this one has a non-empty default of
351 /// [`preflight_request_headers()`].
352 ///
353 /// You only need to set this is you want to remove some of these defaults,
354 /// or if you use a closure for one of the other headers and want to add a
355 /// vary header accordingly.
356 ///
357 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Vary
358 #[must_use]
359 pub fn vary<T>(mut self, headers: impl Into<Vary>) -> Self {
360 self.vary = headers.into();
361 self
362 }
363
364 /// Returns a new `CorsHandler` using current cors settings.
365 pub fn into_handler(self) -> CorsHandler {
366 self.ensure_usable_cors_rules();
367 CorsHandler::new(self, CallNext::default())
368 }
369
370 fn ensure_usable_cors_rules(&self) {
371 if self.allow_credentials.is_true() {
372 assert!(
373 !self.allow_headers.is_wildcard(),
374 "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
375 with `Access-Control-Allow-Headers: *`"
376 );
377
378 assert!(
379 !self.allow_methods.is_wildcard(),
380 "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
381 with `Access-Control-Allow-Methods: *`"
382 );
383
384 assert!(
385 !self.allow_origin.is_wildcard(),
386 "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
387 with `Access-Control-Allow-Origin: *`"
388 );
389
390 assert!(
391 !self.expose_headers.is_wildcard(),
392 "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
393 with `Access-Control-Expose-Headers: *`"
394 );
395 }
396 }
397}
398
/// Controls whether [`CorsHandler`] runs the rest of the handler chain before
/// or after it writes its CORS headers to the response.
#[non_exhaustive]
#[derive(Default, Clone, Copy, Eq, PartialEq, Debug)]
pub enum CallNext {
    /// Run the next handlers first, then let [`CorsHandler`] write its headers
    /// to the response (the default).
    #[default]
    Before,
    /// Let [`CorsHandler`] write its headers to the response first, then run the
    /// next handlers.
    After,
}
409
410/// CorsHandler
411#[derive(Clone, Debug)]
412pub struct CorsHandler {
413 cors: Cors,
414 call_next: CallNext,
415}
416impl CorsHandler {
417 /// Create a new `CorsHandler`.
418 pub fn new(cors: Cors, call_next: CallNext) -> Self {
419 Self { cors, call_next }
420 }
421}
422
423#[async_trait]
424impl Handler for CorsHandler {
425 async fn handle(
426 &self,
427 req: &mut Request,
428 depot: &mut Depot,
429 res: &mut Response,
430 ctrl: &mut FlowCtrl,
431 ) {
432 if self.call_next == CallNext::Before {
433 ctrl.call_next(req, depot, res).await;
434 }
435
436 let origin = req.headers().get(&header::ORIGIN);
437 let mut headers = HeaderMap::new();
438
439 // These headers are applied to both preflight and subsequent regular CORS requests:
440 // https://fetch.spec.whatwg.org/#http-responses
441 headers.extend(self.cors.allow_origin.to_header(origin, req, depot).await);
442 headers.extend(
443 self.cors
444 .allow_credentials
445 .to_header(origin, req, depot)
446 .await,
447 );
448 headers.extend(
449 self.cors
450 .allow_private_network
451 .to_header(origin, req, depot)
452 .await,
453 );
454
455 let mut vary_headers = self.cors.vary.values();
456 if let Some(first) = vary_headers.next() {
457 let mut header = match headers.entry(header::VARY) {
458 header::Entry::Occupied(_) => {
459 unreachable!("no vary header inserted up to this point")
460 }
461 header::Entry::Vacant(v) => v.insert_entry(first),
462 };
463
464 for val in vary_headers {
465 header.append(val);
466 }
467 }
468
469 // Return results immediately upon preflight request
470 if req.method() == Method::OPTIONS {
471 // These headers are applied only to preflight requests
472 headers.extend(self.cors.allow_methods.to_header(origin, req, depot).await);
473 headers.extend(self.cors.allow_headers.to_header(origin, req, depot).await);
474 headers.extend(self.cors.max_age.to_header(origin, req, depot).await);
475 res.status_code = Some(StatusCode::NO_CONTENT);
476 } else {
477 // This header is applied only to non-preflight requests
478 headers.extend(self.cors.expose_headers.to_header(origin, req, depot).await);
479 }
480 res.headers_mut().extend(headers);
481
482 if self.call_next == CallNext::After {
483 ctrl.call_next(req, depot, res).await;
484 }
485 }
486}
487
488/// Iterator over the three request headers that may be involved in a CORS preflight request.
489///
490/// This is the default set of header names returned in the `vary` header
491pub fn preflight_request_headers() -> impl Iterator<Item = HeaderName> {
492 [
493 header::ORIGIN,
494 header::ACCESS_CONTROL_REQUEST_METHOD,
495 header::ACCESS_CONTROL_REQUEST_HEADERS,
496 ]
497 .into_iter()
498}
499
#[cfg(test)]
mod tests {
    use salvo_core::http::header::*;
    use salvo_core::prelude::*;
    use salvo_core::test::TestClient;

    use super::*;

    #[tokio::test]
    async fn test_cors() {
        let cors_handler = Cors::new()
            .allow_origin("https://salvo.rs")
            .allow_methods(vec![Method::GET, Method::POST, Method::OPTIONS])
            .allow_headers(vec![
                "CONTENT-TYPE",
                "Access-Control-Request-Method",
                "Access-Control-Allow-Origin",
                "Access-Control-Allow-Headers",
                "Access-Control-Max-Age",
            ])
            .into_handler();

        #[handler]
        async fn hello() -> &'static str {
            "hello"
        }

        let router = Router::new()
            .hoop(cors_handler)
            .push(Router::with_path("hello").goal(hello));
        let service = Service::new(router);

        // Sends a full preflight request (Origin + Access-Control-Request-*) to `/hello`.
        async fn options_access(service: &Service, origin: &str) -> Response {
            TestClient::options("http://127.0.0.1:5801/hello")
                .add_header("Origin", origin, true)
                .add_header("Access-Control-Request-Method", "POST", true)
                .add_header("Access-Control-Request-Headers", "Content-Type", true)
                .send(service)
                .await
        }

        // A bare OPTIONS request to `/` (no Origin header, not the `hello` route)
        // gets no allowed-methods header.
        let res = TestClient::options("https://salvo.rs").send(&service).await;
        assert!(res.headers().get(ACCESS_CONTROL_ALLOW_METHODS).is_none());

        // A preflight from the configured origin gets the preflight headers.
        let res = options_access(&service, "https://salvo.rs").await;
        let headers = res.headers();
        assert!(headers.get(ACCESS_CONTROL_ALLOW_METHODS).is_some());
        assert!(headers.get(ACCESS_CONTROL_ALLOW_HEADERS).is_some());

        // Another bare OPTIONS request (different host, still no Origin header).
        let res = TestClient::options("https://google.com")
            .send(&service)
            .await;
        let headers = res.headers();
        assert!(
            headers.get(ACCESS_CONTROL_ALLOW_METHODS).is_none(),
            "unexpected Access-Control-Allow-Methods on a request without Origin"
        );
        assert!(headers.get(ACCESS_CONTROL_ALLOW_HEADERS).is_none());

        // A real preflight from an origin that is not configured must not have that
        // origin echoed back as allowed.
        let res = options_access(&service, "https://google.com").await;
        assert_ne!(
            res.headers()
                .get(ACCESS_CONTROL_ALLOW_ORIGIN)
                .map(|value| value.as_bytes()),
            Some(&b"https://google.com"[..]),
            "a disallowed origin was reflected in Access-Control-Allow-Origin"
        );
    }
}