Skip to main content

oxide_framework_core/auth/
extract.rs

//! Authentication extractors: [`Authenticated`], [`OptionalAuth`], and [`RequireRole`].
2
3use std::marker::PhantomData;
4
5use axum::extract::FromRequestParts;
6use axum::http::request::Parts;
7use axum::response::{IntoResponse, Response};
8
9use super::claims::AuthClaims;
10use crate::response::ApiResponse;
11
12/// Requires a valid JWT (middleware must run — use [`crate::App::auth`]).
13pub struct Authenticated(pub AuthClaims);
14
15impl<S: Send + Sync> FromRequestParts<S> for Authenticated {
16    type Rejection = AuthRejection;
17
18    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
19        parts
20            .extensions
21            .get::<AuthClaims>()
22            .cloned()
23            .map(Authenticated)
24            .ok_or(AuthRejection::Unauthorized)
25    }
26}
27
28/// Present when the client sent a valid JWT; [`None`] for anonymous requests.
29pub struct OptionalAuth(pub Option<AuthClaims>);
30
31impl<S: Send + Sync> FromRequestParts<S> for OptionalAuth {
32    type Rejection = std::convert::Infallible;
33
34    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
35        Ok(OptionalAuth(parts.extensions.get::<AuthClaims>().cloned()))
36    }
37}
38
/// Associates a marker type with the role name that [`RequireRole`] checks for.
///
/// Implement it on a zero-sized type. `RequireRole<Marker>` then admits only requests
/// whose claims report `has_role(Marker::ROLE)`.
///
/// # Examples
///
/// ```rust,ignore
/// struct Admin;
/// impl RoleName for Admin {
///     const ROLE: &'static str = "admin";
/// }
///
/// async fn admin_only(_: RequireRole<Admin>) -> ApiResponse<()> {
///     ApiResponse::ok(())
/// }
/// ```
pub trait RoleName
where
    Self: Send + Sync + 'static,
{
    /// The role string compared against the request's claims.
    const ROLE: &'static str;
}
54
/// Role guard: `RequireRole<YourRoleMarker>` where `YourRoleMarker: RoleName`.
///
/// Extraction succeeds only when the request carries claims that satisfy
/// `has_role(YourRoleMarker::ROLE)`. It rejects with [`AuthRejection::Unauthorized`]
/// when no claims are present and with [`AuthRejection::Forbidden`] when the role is
/// missing.
///
/// The `RoleName` bound lives on the extractor impl, not on the type, so the type itself
/// places no requirements on `R`.
pub struct RequireRole<R>(PhantomData<R>);

// `Clone`/`Copy`/`Debug` are implemented by hand instead of derived. `#[derive]` would add
// `R: Clone` / `R: Copy` / `R: Debug` bounds, so `RequireRole<Admin>` would only be
// `Copy` if the zero-sized marker `Admin` also derived those traits.
impl<R> Clone for RequireRole<R> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<R> Copy for RequireRole<R> {}

impl<R> std::fmt::Debug for RequireRole<R> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same output shape as the previous derived impl: `RequireRole(PhantomData<..>)`.
        f.debug_tuple("RequireRole").field(&self.0).finish()
    }
}
58
59impl<S: Send + Sync, R: RoleName> FromRequestParts<S> for RequireRole<R> {
60    type Rejection = AuthRejection;
61
62    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
63        let claims = parts
64            .extensions
65            .get::<AuthClaims>()
66            .ok_or(AuthRejection::Unauthorized)?;
67        if claims.has_role(R::ROLE) {
68            Ok(RequireRole(PhantomData))
69        } else {
70            Err(AuthRejection::Forbidden)
71        }
72    }
73}
74
/// Rejection produced by the auth extractors in this module.
///
/// A plain fieldless enum, so it is cheap to copy and compare (for example, in tests
/// asserting which rejection a handler produced).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthRejection {
    /// No [`AuthClaims`] were found in the request extensions.
    Unauthorized,
    /// Claims were present but lacked the role required by [`RequireRole`].
    Forbidden,
}
81
82impl IntoResponse for AuthRejection {
83    fn into_response(self) -> Response {
84        match self {
85            AuthRejection::Unauthorized => {
86                ApiResponse::<serde_json::Value>::unauthorized("authentication required").into_response()
87            }
88            AuthRejection::Forbidden => {
89                ApiResponse::<serde_json::Value>::forbidden("insufficient permissions").into_response()
90            }
91        }
92    }
93}
94