Skip to main content

allowthem_server/
cors.rs

1use std::collections::HashSet;
2
3use axum::body::Body;
4use axum::extract::State;
5use axum::http::{HeaderMap, HeaderValue, Method, Request, StatusCode, header};
6use axum::middleware::Next;
7use axum::response::{IntoResponse, Response};
8
9use allowthem_core::AllowThem;
10
11/// Bridges `State<AllowThem>` into request extensions so that
12/// [`cors_middleware`] (which reads from extensions) works in standalone mode.
13/// In SaaS mode the tenant router populates extensions directly; this shim is
14/// not used there.
15pub async fn inject_ath_into_extensions(
16    State(ath): State<AllowThem>,
17    mut req: Request<Body>,
18    next: Next,
19) -> Response {
20    req.extensions_mut().insert(ath);
21    next.run(req).await
22}
23
24/// Dynamic CORS middleware for OIDC endpoints.
25///
26/// The allowed-origin set is built per-request from all active applications'
27/// redirect URIs. Requests without an `Origin` header are passed through
28/// unchanged. Returns 500 if `AllowThem` is absent from request extensions
29/// (server misconfiguration — the inject shim was not applied).
30pub(crate) async fn cors_middleware(req: Request<Body>, next: Next) -> Response {
31    let origin_header = req.headers().get(header::ORIGIN).cloned();
32
33    let Some(origin_val) = origin_header else {
34        return next.run(req).await;
35    };
36
37    let origin_str = match origin_val.to_str() {
38        Ok(s) => s.to_owned(),
39        Err(_) => return StatusCode::BAD_REQUEST.into_response(),
40    };
41
42    let ath = match req.extensions().get::<AllowThem>().cloned() {
43        Some(a) => a,
44        None => return StatusCode::INTERNAL_SERVER_ERROR.into_response(),
45    };
46
47    let is_preflight = req.method() == Method::OPTIONS;
48    let allow_set = build_allow_set(&ath).await;
49
50    if !allow_set.contains(&origin_str) {
51        let mut res = StatusCode::FORBIDDEN.into_response();
52        res.headers_mut()
53            .insert(header::VARY, HeaderValue::from_static("Origin"));
54        return res;
55    }
56
57    if is_preflight {
58        let mut headers = HeaderMap::new();
59        headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin_val);
60        headers.insert(
61            header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
62            HeaderValue::from_static("false"),
63        );
64        headers.insert(
65            header::ACCESS_CONTROL_ALLOW_METHODS,
66            HeaderValue::from_static("GET, POST, OPTIONS"),
67        );
68        headers.insert(
69            header::ACCESS_CONTROL_ALLOW_HEADERS,
70            HeaderValue::from_static("Authorization, Content-Type"),
71        );
72        headers.insert(
73            header::ACCESS_CONTROL_MAX_AGE,
74            HeaderValue::from_static("600"),
75        );
76        headers.insert(header::VARY, HeaderValue::from_static("Origin"));
77        return (StatusCode::NO_CONTENT, headers).into_response();
78    }
79
80    let mut res = next.run(req).await;
81    let headers = res.headers_mut();
82    headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin_val);
83    headers.insert(
84        header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
85        HeaderValue::from_static("false"),
86    );
87    headers.insert(header::VARY, HeaderValue::from_static("Origin"));
88    res
89}
90
91async fn build_allow_set(ath: &AllowThem) -> HashSet<String> {
92    let apps = match ath.db().list_applications().await {
93        Ok(a) => a,
94        Err(_) => return HashSet::new(),
95    };
96    apps.iter()
97        .filter(|app| app.is_active)
98        .flat_map(|app| app.redirect_uri_list().ok().unwrap_or_default())
99        .filter_map(|uri| origin_of(uri.trim()))
100        .collect()
101}
102
103fn origin_of(uri: &str) -> Option<String> {
104    let parsed = url::Url::parse(uri).ok()?;
105    match parsed.origin() {
106        url::Origin::Opaque(_) => None,
107        _ => Some(parsed.origin().ascii_serialization()),
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use allowthem_core::AllowThemBuilder;
    use allowthem_core::applications::CreateApplicationParams;
    use axum::Router;
    use axum::http::StatusCode;
    use axum::routing::get;
    use tower::ServiceExt;

    /// Trivial handler standing in for an OIDC endpoint; always `200 OK`.
    async fn dummy() -> StatusCode {
        StatusCode::OK
    }

    /// Builds a `/test` router (GET and POST) wrapped in `cors_middleware`, with
    /// the standalone-mode `inject_ath_into_extensions` shim outside it. The shim
    /// is backed by an in-memory database.
    ///
    /// When `redirect_uris` is non-empty, one application registered with those
    /// URIs is created. Otherwise the database has no applications.
    async fn make_test_app(redirect_uris: Vec<String>) -> Router {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .build()
            .await
            .unwrap();

        if !redirect_uris.is_empty() {
            ath.db()
                .create_application(CreateApplicationParams {
                    name: "TestApp".to_string(),
                    client_type: allowthem_core::ClientType::Confidential,
                    redirect_uris,
                    is_trusted: false,
                    created_by: None,
                    logo_url: None,
                    primary_color: None,
                    accent_hex: None,
                    accent_ink: None,
                    forced_mode: None,
                    font_css_url: None,
                    font_family: None,
                    splash_text: None,
                    splash_image_url: None,
                    splash_primitive: None,
                    splash_url: None,
                    shader_cell_scale: None,
                })
                .await
                .unwrap();
        }

        Router::new()
            .route("/test", get(dummy).post(dummy))
            .layer(axum::middleware::from_fn(cors_middleware))
            .layer(axum::middleware::from_fn_with_state(
                ath.clone(),
                inject_ath_into_extensions,
            ))
    }

    #[tokio::test]
    async fn t1_allowed_origin_passes_through() {
        let app = make_test_app(vec!["https://app.example.com/callback".into()]).await;
        let req = Request::builder()
            .uri("/test")
            .header("Origin", "https://app.example.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get("access-control-allow-origin").unwrap(),
            "https://app.example.com"
        );
        assert_eq!(resp.headers().get("vary").unwrap(), "Origin");
    }

    #[tokio::test]
    async fn t2_disallowed_origin_returns_403() {
        let app = make_test_app(vec!["https://app.example.com/callback".into()]).await;
        let req = Request::builder()
            .uri("/test")
            .header("Origin", "https://evil.example.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
        assert_eq!(resp.headers().get("vary").unwrap(), "Origin");
        assert!(resp.headers().get("access-control-allow-origin").is_none());
    }

    #[tokio::test]
    async fn t3_preflight_allowed_origin_returns_204() {
        let app = make_test_app(vec!["https://app.example.com/callback".into()]).await;
        let req = Request::builder()
            .method("OPTIONS")
            .uri("/test")
            .header("Origin", "https://app.example.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::NO_CONTENT);
        assert_eq!(
            resp.headers().get("access-control-allow-origin").unwrap(),
            "https://app.example.com"
        );
        assert!(resp.headers().get("access-control-allow-methods").is_some());
        assert!(resp.headers().get("access-control-allow-headers").is_some());
        assert_eq!(resp.headers().get("access-control-max-age").unwrap(), "600");
        assert_eq!(resp.headers().get("vary").unwrap(), "Origin");
    }

    #[tokio::test]
    async fn t4_preflight_disallowed_origin_returns_403() {
        let app = make_test_app(vec!["https://app.example.com/callback".into()]).await;
        let req = Request::builder()
            .method("OPTIONS")
            .uri("/test")
            .header("Origin", "https://evil.example.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
        assert!(resp.headers().get("access-control-allow-origin").is_none());
    }

    #[tokio::test]
    async fn t5_no_origin_passes_through_unchanged() {
        let app = make_test_app(vec!["https://app.example.com/callback".into()]).await;
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert!(resp.headers().get("access-control-allow-origin").is_none());
        assert!(resp.headers().get("vary").is_none());
    }

    #[tokio::test]
    async fn t6_empty_application_list_rejects_all_origins() {
        let app = make_test_app(vec![]).await;
        let req = Request::builder()
            .uri("/test")
            .header("Origin", "https://any.example.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn t7_non_visible_ascii_origin_returns_400() {
        let app = make_test_app(vec!["https://app.example.com/callback".into()]).await;
        // 0xE9 is a legal header byte (obs-text) but not visible ASCII, so
        // `HeaderValue::to_str` fails on it.
        let origin = HeaderValue::from_bytes(b"https://caf\xe9.example.com").unwrap();
        let req = Request::builder()
            .uri("/test")
            .header(header::ORIGIN, origin)
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        assert!(resp.headers().get("access-control-allow-origin").is_none());
    }

    #[tokio::test]
    async fn t8_missing_allowthem_extension_returns_500() {
        // No `inject_ath_into_extensions` layer, so the middleware cannot find
        // `AllowThem` and must report the misconfiguration.
        let app = Router::new()
            .route("/test", get(dummy))
            .layer(axum::middleware::from_fn(cors_middleware));
        let req = Request::builder()
            .uri("/test")
            .header("Origin", "https://app.example.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn t9_allowed_origin_post_passes_through() {
        let app = make_test_app(vec!["https://app.example.com/callback".into()]).await;
        let req = Request::builder()
            .method("POST")
            .uri("/test")
            .header("Origin", "https://app.example.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get("access-control-allow-origin").unwrap(),
            "https://app.example.com"
        );
        assert_eq!(
            resp.headers().get("access-control-allow-credentials").unwrap(),
            "false"
        );
    }
}