arcly_http_core/web/
cors.rs1use axum::body::Body;
18use axum::extract::Request;
19use axum::http::{HeaderValue, Method};
20use axum::middleware::Next;
21use axum::response::Response;
22
/// Cross-origin resource sharing (CORS) policy consumed by the `apply_cors` middleware.
///
/// Build one with [`CorsConfig::for_origins`], or start from [`Default`] and adjust it
/// with the builder-style setters.
#[derive(Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub struct CorsConfig {
    /// Origins allowed to make cross-origin requests, e.g. `https://app.example.com`.
    /// The entry `"*"` allows every origin. Entries are compared ignoring ASCII case.
    pub allow_origins: Vec<String>,
    /// Value sent as `Access-Control-Allow-Methods`.
    pub allow_methods: String,
    /// Value sent as `Access-Control-Allow-Headers`.
    pub allow_headers: String,
    /// When `true`, responses carry `Access-Control-Allow-Credentials: true`.
    pub allow_credentials: bool,
    /// Value sent as `Access-Control-Max-Age` on preflight responses, in seconds.
    pub max_age_secs: u32,
}

impl Default for CorsConfig {
    /// No allowed origins (so no CORS headers are ever emitted), the common REST methods,
    /// a set of typical request headers, credentials disabled, and a 600-second
    /// preflight cache lifetime.
    fn default() -> Self {
        Self {
            allow_origins: vec![],
            allow_methods: "GET, POST, PUT, PATCH, DELETE, OPTIONS".into(),
            allow_headers: "content-type, authorization, x-request-id, x-tenant-id, \
                            idempotency-key, traceparent"
                .into(),
            allow_credentials: false,
            max_age_secs: 600,
        }
    }
}

impl CorsConfig {
    /// Policy that allows exactly `origins`, with credentials enabled; every other field
    /// takes its [`Default`] value.
    ///
    /// NOTE(review): passing `"*"` here combines the wildcard with credentials, which makes
    /// the middleware echo *any* requesting origin with credentials allowed.
    #[must_use]
    pub fn for_origins<I, S>(origins: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        Self {
            allow_origins: origins.into_iter().map(Into::into).collect(),
            allow_credentials: true,
            ..Default::default()
        }
    }

    /// Replaces the allowed-origin list.
    #[must_use]
    pub fn allow_origins<I, S>(mut self, origins: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.allow_origins = origins.into_iter().map(Into::into).collect();
        self
    }

    /// Replaces the `Access-Control-Allow-Methods` value.
    #[must_use]
    pub fn allow_methods(mut self, v: impl Into<String>) -> Self {
        self.allow_methods = v.into();
        self
    }

    /// Replaces the `Access-Control-Allow-Headers` value.
    #[must_use]
    pub fn allow_headers(mut self, v: impl Into<String>) -> Self {
        self.allow_headers = v.into();
        self
    }

    /// Enables or disables `Access-Control-Allow-Credentials: true`.
    #[must_use]
    pub fn allow_credentials(mut self, v: bool) -> Self {
        self.allow_credentials = v;
        self
    }

    /// Sets the preflight cache lifetime (`Access-Control-Max-Age`) in seconds.
    #[must_use]
    pub fn max_age_secs(mut self, v: u32) -> Self {
        self.max_age_secs = v;
        self
    }

    /// Chooses the `Access-Control-Allow-Origin` value for a request whose `Origin`
    /// header is `origin`, or `None` if the origin is not allowed.
    ///
    /// * `"*"` configured, credentials disabled: returns `Some("*")`.
    /// * `"*"` configured, credentials enabled: returns the request's own origin. The
    ///   CORS protocol does not honour a literal `*` together with credentials.
    /// * `origin` matches a configured entry: returns the request's own origin.
    /// * Otherwise: returns `None`.
    ///
    /// Entries are compared ignoring ASCII case. Browsers send the scheme and host of
    /// `Origin` in lower case, so an exact comparison would silently reject a
    /// configured `https://App.Example.com`.
    ///
    /// NOTE(review): the wildcard-with-credentials branch allows any site to make
    /// credentialed requests; confirm this is intended wherever it is configured.
    fn allowed_origin(&self, origin: &str) -> Option<String> {
        let wildcard = self.allow_origins.iter().any(|o| o == "*");
        if wildcard && !self.allow_credentials {
            return Some("*".to_owned());
        }
        if wildcard
            || self
                .allow_origins
                .iter()
                .any(|o| o.eq_ignore_ascii_case(origin))
        {
            // Echo the request's origin (not the configured spelling) so the browser's
            // exact-match check succeeds.
            return Some(origin.to_owned());
        }
        None
    }
}
101
102#[doc(hidden)]
103pub async fn apply_cors(cfg: std::sync::Arc<CorsConfig>, req: Request, next: Next) -> Response {
104 let origin = req
105 .headers()
106 .get("origin")
107 .and_then(|v| v.to_str().ok())
108 .map(str::to_owned);
109
110 let allowed = origin.as_deref().and_then(|o| cfg.allowed_origin(o));
111
112 let is_preflight = req.method() == Method::OPTIONS
114 && req.headers().contains_key("access-control-request-method");
115 if is_preflight {
116 let Some(echo) = allowed else {
117 return Response::builder()
119 .status(403)
120 .body(Body::empty())
121 .expect("static preflight denial");
122 };
123 let mut resp = Response::builder()
124 .status(204)
125 .body(Body::empty())
126 .expect("static preflight response");
127 set_cors_headers(resp.headers_mut(), &cfg, &echo);
128 resp.headers_mut().insert(
129 "access-control-max-age",
130 HeaderValue::from_str(&cfg.max_age_secs.to_string()).expect("numeric"),
131 );
132 return resp;
133 }
134
135 let mut resp = next.run(req).await;
136 if let Some(echo) = allowed {
137 set_cors_headers(resp.headers_mut(), &cfg, &echo);
138 }
139 resp
140}
141
142fn set_cors_headers(headers: &mut axum::http::HeaderMap, cfg: &CorsConfig, origin: &str) {
143 if let Ok(v) = HeaderValue::from_str(origin) {
144 headers.insert("access-control-allow-origin", v);
145 }
146 if let Ok(v) = HeaderValue::from_str(&cfg.allow_methods) {
147 headers.insert("access-control-allow-methods", v);
148 }
149 if let Ok(v) = HeaderValue::from_str(&cfg.allow_headers) {
150 headers.insert("access-control-allow-headers", v);
151 }
152 if cfg.allow_credentials {
153 headers.insert(
154 "access-control-allow-credentials",
155 HeaderValue::from_static("true"),
156 );
157 }
158 headers.append("vary", HeaderValue::from_static("origin"));
159}