1use std::collections::HashSet;
2
3use axum::body::Body;
4use axum::extract::State;
5use axum::http::{HeaderMap, HeaderValue, Method, Request, StatusCode, header};
6use axum::middleware::Next;
7use axum::response::{IntoResponse, Response};
8
9use allowthem_core::AllowThem;
10
11pub async fn inject_ath_into_extensions(
16 State(ath): State<AllowThem>,
17 mut req: Request<Body>,
18 next: Next,
19) -> Response {
20 req.extensions_mut().insert(ath);
21 next.run(req).await
22}
23
24pub(crate) async fn cors_middleware(req: Request<Body>, next: Next) -> Response {
31 let origin_header = req.headers().get(header::ORIGIN).cloned();
32
33 let Some(origin_val) = origin_header else {
34 return next.run(req).await;
35 };
36
37 let origin_str = match origin_val.to_str() {
38 Ok(s) => s.to_owned(),
39 Err(_) => return StatusCode::BAD_REQUEST.into_response(),
40 };
41
42 let ath = match req.extensions().get::<AllowThem>().cloned() {
43 Some(a) => a,
44 None => return StatusCode::INTERNAL_SERVER_ERROR.into_response(),
45 };
46
47 let is_preflight = req.method() == Method::OPTIONS;
48 let allow_set = build_allow_set(&ath).await;
49
50 if !allow_set.contains(&origin_str) {
51 let mut res = StatusCode::FORBIDDEN.into_response();
52 res.headers_mut()
53 .insert(header::VARY, HeaderValue::from_static("Origin"));
54 return res;
55 }
56
57 if is_preflight {
58 let mut headers = HeaderMap::new();
59 headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin_val);
60 headers.insert(
61 header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
62 HeaderValue::from_static("false"),
63 );
64 headers.insert(
65 header::ACCESS_CONTROL_ALLOW_METHODS,
66 HeaderValue::from_static("GET, POST, OPTIONS"),
67 );
68 headers.insert(
69 header::ACCESS_CONTROL_ALLOW_HEADERS,
70 HeaderValue::from_static("Authorization, Content-Type"),
71 );
72 headers.insert(
73 header::ACCESS_CONTROL_MAX_AGE,
74 HeaderValue::from_static("600"),
75 );
76 headers.insert(header::VARY, HeaderValue::from_static("Origin"));
77 return (StatusCode::NO_CONTENT, headers).into_response();
78 }
79
80 let mut res = next.run(req).await;
81 let headers = res.headers_mut();
82 headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin_val);
83 headers.insert(
84 header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
85 HeaderValue::from_static("false"),
86 );
87 headers.insert(header::VARY, HeaderValue::from_static("Origin"));
88 res
89}
90
91async fn build_allow_set(ath: &AllowThem) -> HashSet<String> {
92 let apps = match ath.db().list_applications().await {
93 Ok(a) => a,
94 Err(_) => return HashSet::new(),
95 };
96 apps.iter()
97 .filter(|app| app.is_active)
98 .flat_map(|app| app.redirect_uri_list().ok().unwrap_or_default())
99 .filter_map(|uri| origin_of(uri.trim()))
100 .collect()
101}
102
103fn origin_of(uri: &str) -> Option<String> {
104 let parsed = url::Url::parse(uri).ok()?;
105 match parsed.origin() {
106 url::Origin::Opaque(_) => None,
107 _ => Some(parsed.origin().ascii_serialization()),
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use allowthem_core::AllowThemBuilder;
    use allowthem_core::applications::CreateApplicationParams;
    use axum::Router;
    use axum::http::StatusCode;
    use axum::routing::get;
    use tower::ServiceExt;

    /// Redirect URI registered for the test application.
    const CALLBACK_URI: &str = "https://app.example.com/callback";
    /// Origin derived from `CALLBACK_URI`; expected to be accepted.
    const ALLOWED_ORIGIN: &str = "https://app.example.com";
    /// Origin that no registered application maps to.
    const UNKNOWN_ORIGIN: &str = "https://evil.example.com";

    /// Trivial handler standing in for the protected route.
    async fn dummy() -> StatusCode {
        StatusCode::OK
    }

    /// Builds a router backed by an in-memory database, with `/test` wrapped in
    /// the CORS middleware. When `redirect_uris` is non-empty, one application
    /// owning those URIs is registered first.
    async fn make_test_app(redirect_uris: Vec<String>) -> Router {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .build()
            .await
            .unwrap();

        if !redirect_uris.is_empty() {
            let params = CreateApplicationParams {
                name: "TestApp".to_string(),
                client_type: allowthem_core::ClientType::Confidential,
                redirect_uris,
                is_trusted: false,
                created_by: None,
                logo_url: None,
                primary_color: None,
                accent_hex: None,
                accent_ink: None,
                forced_mode: None,
                font_css_url: None,
                font_family: None,
                splash_text: None,
                splash_image_url: None,
                splash_primitive: None,
                splash_url: None,
                shader_cell_scale: None,
            };
            ath.db().create_application(params).await.unwrap();
        }

        // The layer added last is the outermost one, so the handle is injected
        // before the CORS middleware looks for it.
        let inject_layer =
            axum::middleware::from_fn_with_state(ath.clone(), inject_ath_into_extensions);
        Router::new()
            .route("/test", get(dummy).post(dummy))
            .layer(axum::middleware::from_fn(cors_middleware))
            .layer(inject_layer)
    }

    /// Sends a bodiless request to `/test` with the given method and optional
    /// `Origin` header, returning the response.
    async fn send(app: Router, method: &str, origin: Option<&str>) -> Response {
        let mut builder = Request::builder().method(method).uri("/test");
        if let Some(origin) = origin {
            builder = builder.header("Origin", origin);
        }
        let req = builder.body(Body::empty()).unwrap();
        app.oneshot(req).await.unwrap()
    }

    #[tokio::test]
    async fn t1_allowed_origin_passes_through() {
        let app = make_test_app(vec![CALLBACK_URI.into()]).await;
        let resp = send(app, "GET", Some(ALLOWED_ORIGIN)).await;

        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get("access-control-allow-origin").unwrap(),
            ALLOWED_ORIGIN
        );
        assert_eq!(resp.headers().get("vary").unwrap(), "Origin");
    }

    #[tokio::test]
    async fn t2_disallowed_origin_returns_403() {
        let app = make_test_app(vec![CALLBACK_URI.into()]).await;
        let resp = send(app, "GET", Some(UNKNOWN_ORIGIN)).await;

        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
        assert_eq!(resp.headers().get("vary").unwrap(), "Origin");
        assert!(resp.headers().get("access-control-allow-origin").is_none());
    }

    #[tokio::test]
    async fn t3_preflight_allowed_origin_returns_204() {
        let app = make_test_app(vec![CALLBACK_URI.into()]).await;
        let resp = send(app, "OPTIONS", Some(ALLOWED_ORIGIN)).await;

        assert_eq!(resp.status(), StatusCode::NO_CONTENT);
        let headers = resp.headers();
        assert_eq!(
            headers.get("access-control-allow-origin").unwrap(),
            ALLOWED_ORIGIN
        );
        assert!(headers.get("access-control-allow-methods").is_some());
        assert!(headers.get("access-control-allow-headers").is_some());
        assert_eq!(headers.get("access-control-max-age").unwrap(), "600");
        assert_eq!(headers.get("vary").unwrap(), "Origin");
    }

    #[tokio::test]
    async fn t4_preflight_disallowed_origin_returns_403() {
        let app = make_test_app(vec![CALLBACK_URI.into()]).await;
        let resp = send(app, "OPTIONS", Some(UNKNOWN_ORIGIN)).await;

        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
        assert!(resp.headers().get("access-control-allow-origin").is_none());
    }

    #[tokio::test]
    async fn t5_no_origin_passes_through_unchanged() {
        let app = make_test_app(vec![CALLBACK_URI.into()]).await;
        let resp = send(app, "GET", None).await;

        assert_eq!(resp.status(), StatusCode::OK);
        assert!(resp.headers().get("access-control-allow-origin").is_none());
        assert!(resp.headers().get("vary").is_none());
    }

    #[tokio::test]
    async fn t6_empty_application_list_rejects_all_origins() {
        let app = make_test_app(vec![]).await;
        let resp = send(app, "GET", Some("https://any.example.com")).await;

        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }
}