Skip to main content

allowthem_server/
middleware.rs

1use axum::body::Body;
2use axum::extract::FromRef;
3use axum::http::{Request, StatusCode, header::COOKIE};
4use axum::middleware::Next;
5use axum::response::{IntoResponse, Response};
6use serde_json::json;
7
8use std::sync::Arc;
9
10use allowthem_core::{AuthClient, AuthError, PermissionName, RoleName, User, parse_session_cookie};
11
12/// Axum middleware that requires a valid authenticated session.
13///
14/// Validates the session cookie, fetches the user, and inserts the [`User`]
15/// into request extensions so downstream handlers can access it cheaply via
16/// `Extension<User>`. Returns 401 JSON on any authentication failure.
17///
18/// Apply to a route group with `axum::middleware::from_fn_with_state(ath, require_auth)`.
19pub async fn require_auth<S>(
20    state: axum::extract::State<S>,
21    mut request: Request<Body>,
22    next: Next,
23) -> Response
24where
25    Arc<dyn AuthClient>: FromRef<S>,
26    S: Send + Sync + Clone,
27{
28    let client = <Arc<dyn AuthClient>>::from_ref(&*state);
29    // Clone the headers out before any await so we don't hold &Request<Body>
30    // (Body is not Sync) across an await point.
31    let headers = request.headers().clone();
32
33    let user = match authenticate(&*client, &headers).await {
34        Ok(u) => u,
35        Err(r) => return r,
36    };
37
38    request.extensions_mut().insert(user);
39    next.run(request).await
40}
41
42/// Middleware factory that requires the authenticated user to have a specific role.
43///
44/// Builds on `require_auth`: first validates the session (inserting `User` into
45/// extensions), then checks the role. Returns 401 if not authenticated, 403 if
46/// authenticated but missing the role.
47///
48/// Usage:
49/// ```ignore
50/// use axum::middleware;
51///
52/// let app = Router::new()
53///     .route("/admin", get(handler))
54///     .layer(middleware::from_fn_with_state(
55///         ath.clone(),
56///         require_role("admin"),
57///     ));
58/// ```
59pub fn require_role<S>(
60    role: impl Into<String>,
61) -> impl Fn(
62    axum::extract::State<S>,
63    Request<Body>,
64    Next,
65) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send>>
66+ Clone
67+ Send
68+ 'static
69where
70    Arc<dyn AuthClient>: FromRef<S>,
71    S: Send + Sync + Clone + 'static,
72{
73    let role_name = role.into();
74    move |state, request, next| {
75        let role_name = role_name.clone();
76        Box::pin(require_role_inner(state, request, next, role_name))
77    }
78}
79
80async fn require_role_inner<S>(
81    state: axum::extract::State<S>,
82    mut request: Request<Body>,
83    next: Next,
84    role_name: String,
85) -> Response
86where
87    Arc<dyn AuthClient>: FromRef<S>,
88    S: Send + Sync + Clone,
89{
90    let client = <Arc<dyn AuthClient>>::from_ref(&*state);
91    let headers = request.headers().clone();
92
93    let user = match authenticate(&*client, &headers).await {
94        Ok(u) => u,
95        Err(r) => return r,
96    };
97
98    let rn = RoleName::new(role_name);
99    match client.check_role(&user.id, &rn).await {
100        Ok(true) => {}
101        Ok(false) => {
102            return (
103                StatusCode::FORBIDDEN,
104                axum::Json(json!({"error": "forbidden"})),
105            )
106                .into_response();
107        }
108        Err(e) => return internal_error(e),
109    }
110
111    request.extensions_mut().insert(user);
112    next.run(request).await
113}
114
115/// Middleware factory that requires the authenticated user to have a specific permission.
116///
117/// Works identically to [`require_role`] but checks permissions instead of roles.
118/// Permissions are checked via both direct assignment and role membership.
119///
120/// Returns 401 if not authenticated, 403 if authenticated but missing the permission.
121pub fn require_permission<S>(
122    permission: impl Into<String>,
123) -> impl Fn(
124    axum::extract::State<S>,
125    Request<Body>,
126    Next,
127) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send>>
128+ Clone
129+ Send
130+ 'static
131where
132    Arc<dyn AuthClient>: FromRef<S>,
133    S: Send + Sync + Clone + 'static,
134{
135    let perm_name = permission.into();
136    move |state, request, next| {
137        let perm_name = perm_name.clone();
138        Box::pin(require_permission_inner(state, request, next, perm_name))
139    }
140}
141
142async fn require_permission_inner<S>(
143    state: axum::extract::State<S>,
144    mut request: Request<Body>,
145    next: Next,
146    perm_name: String,
147) -> Response
148where
149    Arc<dyn AuthClient>: FromRef<S>,
150    S: Send + Sync + Clone,
151{
152    let client = <Arc<dyn AuthClient>>::from_ref(&*state);
153    let headers = request.headers().clone();
154
155    let user = match authenticate(&*client, &headers).await {
156        Ok(u) => u,
157        Err(r) => return r,
158    };
159
160    let pn = PermissionName::new(perm_name);
161    match client.check_permission(&user.id, &pn).await {
162        Ok(true) => {}
163        Ok(false) => {
164            return (
165                StatusCode::FORBIDDEN,
166                axum::Json(json!({"error": "forbidden"})),
167            )
168                .into_response();
169        }
170        Err(e) => return internal_error(e),
171    }
172
173    request.extensions_mut().insert(user);
174    next.run(request).await
175}
176
177/// Shared authentication logic: parse cookie and validate session.
178///
179/// Takes the headers directly so the caller does not hold a `&Request<Body>` reference
180/// across await points (Body is not Sync).
181///
182/// Returns the active `User` on success, or an `IntoResponse` error response.
183async fn authenticate(
184    client: &dyn AuthClient,
185    headers: &axum::http::HeaderMap,
186) -> Result<User, Response> {
187    let cookie_header = headers
188        .get(COOKIE)
189        .and_then(|v| v.to_str().ok())
190        .ok_or_else(unauthenticated)?
191        .to_string();
192
193    let token = parse_session_cookie(&cookie_header, client.session_cookie_name())
194        .ok_or_else(unauthenticated)?;
195
196    let user = client
197        .validate_session(&token)
198        .await
199        .map_err(internal_error)?
200        .ok_or_else(unauthenticated)?;
201
202    Ok(user)
203}
204
205fn unauthenticated() -> Response {
206    (
207        StatusCode::UNAUTHORIZED,
208        axum::Json(json!({"error": "unauthenticated"})),
209    )
210        .into_response()
211}
212
213fn internal_error(err: AuthError) -> Response {
214    tracing::error!("auth middleware error: {err}");
215    (
216        StatusCode::INTERNAL_SERVER_ERROR,
217        axum::Json(json!({"error": "internal error"})),
218    )
219        .into_response()
220}
221
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use allowthem_core::{
        AllowThem, AllowThemBuilder, AuthClient, Email, EmbeddedAuthClient, generate_token,
        hash_token,
    };
    use axum::extract::FromRef;
    use axum::http::StatusCode;
    use axum::routing::get;
    use axum::{Extension, Router, middleware};
    use chrono::{Duration, Utc};
    use tower::ServiceExt;

    /// Router state exposing the auth client through `FromRef`, as an application would.
    #[derive(Clone)]
    struct TestState {
        auth: Arc<dyn AuthClient>,
    }

    impl FromRef<TestState> for Arc<dyn AuthClient> {
        fn from_ref(s: &TestState) -> Self {
            Arc::clone(&s.auth)
        }
    }

    /// Build AllowThem, create a user with an active session, return (AllowThem, cookie_value).
    ///
    /// `cookie_value` is the leading `name=value` segment of the session cookie string,
    /// i.e. what a client would send back in its `Cookie` header.
    async fn test_setup() -> (AllowThem, String) {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .cookie_secure(false)
            .build()
            .await
            .unwrap();

        let email = Email::new("user@example.com".into()).unwrap();
        let user = ath
            .db()
            .create_user(email, "password123", None)
            .await
            .unwrap();

        let token = generate_token();
        let token_hash = hash_token(&token);
        let expires = Utc::now() + Duration::hours(24);
        ath.db()
            .create_session(user.id, token_hash, None, None, expires)
            .await
            .unwrap();

        let cookie = ath.session_cookie(&token);
        let cookie_value = cookie.split(';').next().unwrap().to_string();
        (ath, cookie_value)
    }

    async fn ok_handler() -> StatusCode {
        StatusCode::OK
    }

    /// Returns 200 only if the middleware left a `User` in request extensions;
    /// axum rejects the request before this runs when the extension is missing.
    async fn user_handler(Extension(_user): Extension<User>) -> StatusCode {
        StatusCode::OK
    }

    /// Wraps `ath` in an embedded auth client behind the test router state.
    fn test_state(ath: AllowThem) -> TestState {
        let auth: Arc<dyn AuthClient> = Arc::new(EmbeddedAuthClient::new(ath, "/login"));
        TestState { auth }
    }

    fn auth_app(ath: AllowThem) -> Router {
        let state = test_state(ath);
        Router::new()
            .route("/protected", get(ok_handler))
            .layer(middleware::from_fn_with_state(
                state.clone(),
                require_auth::<TestState>,
            ))
            .with_state(state)
    }

    /// Like `auth_app`, but the handler requires the `User` extension.
    fn auth_user_app(ath: AllowThem) -> Router {
        let state = test_state(ath);
        Router::new()
            .route("/protected", get(user_handler))
            .layer(middleware::from_fn_with_state(
                state.clone(),
                require_auth::<TestState>,
            ))
            .with_state(state)
    }

    fn role_app(ath: AllowThem, role: &str) -> Router {
        let role = role.to_string();
        let state = test_state(ath);
        Router::new()
            .route("/protected", get(ok_handler))
            .layer(middleware::from_fn_with_state(
                state.clone(),
                require_role::<TestState>(role),
            ))
            .with_state(state)
    }

    fn perm_app(ath: AllowThem, perm: &str) -> Router {
        let perm = perm.to_string();
        let state = test_state(ath);
        Router::new()
            .route("/protected", get(ok_handler))
            .layer(middleware::from_fn_with_state(
                state.clone(),
                require_permission::<TestState>(perm),
            ))
            .with_state(state)
    }

    fn make_request(cookie: Option<&str>) -> axum::http::Request<Body> {
        let mut builder = axum::http::Request::builder().uri("/protected");
        if let Some(c) = cookie {
            builder = builder.header(COOKIE, c);
        }
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn authenticated_request_passes_through() {
        let (ath, cookie) = test_setup().await;
        let app = auth_app(ath);
        let resp = app.oneshot(make_request(Some(&cookie))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn authenticated_request_exposes_user_extension() {
        let (ath, cookie) = test_setup().await;
        let app = auth_user_app(ath);
        let resp = app.oneshot(make_request(Some(&cookie))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn unauthenticated_request_returns_401() {
        let (ath, _) = test_setup().await;
        let app = auth_app(ath);
        let resp = app.oneshot(make_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn cookie_header_without_session_cookie_returns_401() {
        let (ath, _) = test_setup().await;
        let app = auth_app(ath);
        let resp = app
            .oneshot(make_request(Some("unrelated_cookie=value")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn require_role_with_correct_role_passes() {
        let (ath, cookie) = test_setup().await;

        // Create role and assign to user.
        let rn = allowthem_core::RoleName::new("admin");
        let role = ath.db().create_role(&rn, None).await.unwrap();
        let email = Email::new("user@example.com".into()).unwrap();
        let user = ath.db().get_user_by_email(&email).await.unwrap();
        ath.db().assign_role(&user.id, &role.id).await.unwrap();

        let app = role_app(ath, "admin");
        let resp = app.oneshot(make_request(Some(&cookie))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_role_with_wrong_role_returns_403() {
        let (ath, cookie) = test_setup().await;
        // User has no roles assigned.
        let app = role_app(ath, "admin");
        let resp = app.oneshot(make_request(Some(&cookie))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn require_role_without_session_returns_401() {
        let (ath, _) = test_setup().await;
        let app = role_app(ath, "admin");
        let resp = app.oneshot(make_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn require_permission_with_correct_permission_passes() {
        let (ath, cookie) = test_setup().await;

        // Create permission and assign directly to user.
        let pn = allowthem_core::PermissionName::new("posts:write");
        let perm = ath.db().create_permission(&pn, None).await.unwrap();
        let email = Email::new("user@example.com".into()).unwrap();
        let user = ath.db().get_user_by_email(&email).await.unwrap();
        ath.db()
            .assign_permission_to_user(&user.id, &perm.id)
            .await
            .unwrap();

        let app = perm_app(ath, "posts:write");
        let resp = app.oneshot(make_request(Some(&cookie))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_permission_with_missing_permission_returns_403() {
        let (ath, cookie) = test_setup().await;
        // User has no permissions.
        let app = perm_app(ath, "posts:write");
        let resp = app.oneshot(make_request(Some(&cookie))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn require_permission_without_session_returns_401() {
        let (ath, _) = test_setup().await;
        let app = perm_app(ath, "posts:write");
        let resp = app.oneshot(make_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }
}