Skip to main content

modkit_auth/
axum_ext.rs

1//! Axum extractors and middleware for auth
2
3use crate::{
4    claims::Claims,
5    errors::AuthError,
6    traits::{PrimaryAuthorizer, TokenValidator},
7    types::{AuthRequirement, RoutePolicy},
8};
9use axum::{
10    body::Body,
11    extract::{FromRequestParts, Request},
12    http::{HeaderMap, Method, request::Parts},
13    response::{IntoResponse, Response},
14};
15use modkit_security::SecurityContext;
16use std::{
17    future::Future,
18    pin::Pin,
19    sync::Arc,
20    task::{Context, Poll},
21};
22use tower::{Layer, Service};
23
24/// Extractor for `SecurityContext` - validates that auth middleware has run
25#[derive(Debug, Clone)]
26pub struct Authz(pub SecurityContext);
27
28impl<S> FromRequestParts<S> for Authz
29where
30    S: Send + Sync,
31{
32    type Rejection = AuthError;
33
34    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
35        parts
36            .extensions
37            .get::<SecurityContext>()
38            .cloned() // TODO: drop this clone
39            .map(Authz)
40            .ok_or(AuthError::Internal(
41                "SecurityContext not found - auth middleware not configured".to_owned(),
42            ))
43    }
44}
45
46/// Extractor for Claims - validates that auth middleware has run
47#[derive(Debug, Clone)]
48pub struct AuthClaims(pub Claims);
49
50impl<S> FromRequestParts<S> for AuthClaims
51where
52    S: Send + Sync,
53{
54    type Rejection = AuthError;
55
56    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
57        parts
58            .extensions
59            .get::<Claims>()
60            .cloned() // TODO: drop this clone
61            .map(AuthClaims)
62            .ok_or(AuthError::Internal(
63                "Claims not found - auth middleware not configured".to_owned(),
64            ))
65    }
66}
67
68/// Shared state for authentication policy middleware.
69struct AuthPolicyState {
70    validator: Arc<dyn TokenValidator>,
71    authorizer: Arc<dyn PrimaryAuthorizer>,
72    policy: Arc<dyn RoutePolicy>,
73}
74
75/// Layer that applies authentication policy middleware to services.
76///
77/// # Example
78/// ```ignore
79/// router = router.layer(AuthPolicyLayer::new(validator, authorizer, policy));
80/// ```
81#[derive(Clone)]
82pub struct AuthPolicyLayer {
83    state: Arc<AuthPolicyState>,
84}
85
86impl AuthPolicyLayer {
87    pub fn new(
88        validator: Arc<dyn TokenValidator>,
89        authorizer: Arc<dyn PrimaryAuthorizer>,
90        policy: Arc<dyn RoutePolicy>,
91    ) -> Self {
92        Self {
93            state: Arc::new(AuthPolicyState {
94                validator,
95                authorizer,
96                policy,
97            }),
98        }
99    }
100}
101
102impl<S> Layer<S> for AuthPolicyLayer {
103    type Service = AuthPolicyService<S>;
104
105    fn layer(&self, inner: S) -> Self::Service {
106        AuthPolicyService {
107            inner,
108            state: self.state.clone(),
109        }
110    }
111}
112
113/// Service that applies authentication policy to requests.
114#[derive(Clone)]
115pub struct AuthPolicyService<S> {
116    inner: S,
117    state: Arc<AuthPolicyState>,
118}
119
120impl<S> Service<Request<Body>> for AuthPolicyService<S>
121where
122    S: Service<Request<Body>, Response = Response> + Clone + Send + 'static,
123    S::Future: Send,
124{
125    type Response = Response;
126    type Error = S::Error;
127    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
128
129    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
130        self.inner.poll_ready(cx)
131    }
132
133    fn call(&mut self, mut request: Request<Body>) -> Self::Future {
134        let state = self.state.clone();
135        let not_ready_inner = self.inner.clone();
136        let mut ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);
137
138        Box::pin(async move {
139            // 1. Skips authentication for CORS preflight requests
140            if is_preflight_request(request.method(), request.headers()) {
141                return ready_inner.call(request).await;
142            }
143
144            // 2. Resolves the route's authentication requirement using RoutePolicy
145            let auth_requirement = state
146                .policy
147                .resolve(request.method(), request.uri().path())
148                .await;
149
150            match auth_requirement {
151                AuthRequirement::None => {
152                    // 3. For public routes (AuthRequirement::None): inserts anonymous SecurityContext
153                    request
154                        .extensions_mut()
155                        .insert(SecurityContext::anonymous());
156                    ready_inner.call(request).await
157                }
158                AuthRequirement::Required(sec_requirement) => {
159                    // 4. For required routes: validates JWT, enforces RBAC if needed, inserts SecurityContext
160                    let Some(token) = extract_bearer_token(request.headers()) else {
161                        return Ok(AuthError::Unauthenticated.into_response());
162                    };
163
164                    let claims = match state.validator.validate_and_parse(token).await {
165                        Ok(claims) => claims,
166                        Err(err) => {
167                            return Ok(err.into_response());
168                        }
169                    };
170
171                    // Optional RBAC requirement
172                    if let Some(sec_req) = sec_requirement
173                        && let Err(err) = state.authorizer.check(&claims, &sec_req).await
174                    {
175                        return Ok(err.into_response());
176                    }
177
178                    // Build SecurityContext from validated claims
179                    let sec_context = SecurityContext::builder()
180                        .tenant_id(claims.tenant_id)
181                        .subject_id(claims.subject)
182                        .build();
183
184                    request.extensions_mut().insert(claims);
185                    request.extensions_mut().insert(sec_context);
186                    ready_inner.call(request).await
187                }
188                AuthRequirement::Optional => {
189                    // 5. For optional routes: validates JWT if present, otherwise inserts anonymous SecurityContext
190                    if let Some(token) = extract_bearer_token(request.headers()) {
191                        match state.validator.validate_and_parse(token).await {
192                            Ok(claims) => {
193                                // Build SecurityContext from validated claims
194                                let sec_context = SecurityContext::builder()
195                                    .tenant_id(claims.tenant_id)
196                                    .subject_id(claims.subject)
197                                    .build();
198
199                                request.extensions_mut().insert(claims);
200                                request.extensions_mut().insert(sec_context);
201                            }
202                            Err(err) => {
203                                tracing::debug!("Optional auth: invalid token: {err}");
204                                request
205                                    .extensions_mut()
206                                    .insert(SecurityContext::anonymous());
207                            }
208                        }
209                    } else {
210                        request
211                            .extensions_mut()
212                            .insert(SecurityContext::anonymous());
213                    }
214                    ready_inner.call(request).await
215                }
216            }
217        })
218    }
219}
220
221/// Extract Bearer token from Authorization header
222fn extract_bearer_token(headers: &HeaderMap) -> Option<&str> {
223    headers
224        .get(axum::http::header::AUTHORIZATION)
225        .and_then(|v| v.to_str().ok())
226        .and_then(|s| s.strip_prefix("Bearer ").map(str::trim))
227}
228
229/// Check if this is a CORS preflight request
230///
231/// Preflight requests are OPTIONS requests with:
232/// - Origin header present
233/// - Access-Control-Request-Method header present
234fn is_preflight_request(method: &Method, headers: &HeaderMap) -> bool {
235    method == Method::OPTIONS
236        && headers.contains_key(axum::http::header::ORIGIN)
237        && headers.contains_key(axum::http::header::ACCESS_CONTROL_REQUEST_METHOD)
238}
239
// Note: AuthPolicyLayer has no unit tests in this module. It is covered by the integration
// tests in tests/auth_integration.rs, because exercising it requires the full Axum middleware stack.