libauth_rs/middleware/axum.rs

1use crate::error::{AuthError, AuthResult};
2use crate::providers::AuthProvider;
3use crate::types::{AuthProvider as AuthProviderType, AuthToken, UserContext};
4use axum::{
5    extract::{FromRequestParts, Request},
6    http::{request::Parts, StatusCode},
7    response::{IntoResponse, Response},
8};
9use std::collections::HashMap;
10use std::sync::Arc;
11
12/// Extension for accessing UserContext in handlers
13#[derive(Debug, Clone)]
14pub struct AuthExtension {
15    pub user: UserContext,
16    /// Reference to the provider that authenticated this user
17    /// This allows for provider-specific authorization checks
18    provider_index: usize,
19}
20
21impl AuthExtension {
22    /// Get the user context
23    pub fn user(&self) -> &UserContext {
24        &self.user
25    }
26
27    /// Get mutable user context
28    pub fn user_mut(&mut self) -> &mut UserContext {
29        &mut self.user
30    }
31}
32
33/// Authentication state that holds configured providers
34#[derive(Clone)]
35pub struct AuthState {
36    providers: Arc<Vec<Box<dyn AuthProvider>>>,
37    required: bool,
38    /// Optional mapping from issuer to provider type for issuer-based routing
39    issuer_to_provider: Arc<Option<HashMap<String, AuthProviderType>>>,
40}
41
42impl AuthState {
43    /// Create a new auth state with providers
44    pub fn new(providers: Vec<Box<dyn AuthProvider>>) -> Self {
45        Self {
46            providers: Arc::new(providers),
47            required: false,
48            issuer_to_provider: Arc::new(None),
49        }
50    }
51
52    /// Set whether authentication is required (returns 401 if not authenticated)
53    pub fn required(mut self, required: bool) -> Self {
54        self.required = required;
55        self
56    }
57
58    /// Set issuer-to-provider mapping for routing based on JWT issuer (iss) claim
59    pub fn with_issuer_mapping(mut self, mapping: HashMap<String, AuthProviderType>) -> Self {
60        self.issuer_to_provider = Arc::new(Some(mapping));
61        self
62    }
63
64    /// Extract issuer (iss) from JWT token without full verification
65    fn extract_issuer_from_token(&self, token: &str) -> Option<String> {
66        use base64::Engine;
67
68        let parts: Vec<&str> = token.split('.').collect();
69        if parts.len() != 3 {
70            return None;
71        }
72
73        // Decode payload
74        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
75            .decode(parts[1])
76            .ok()?;
77
78        let payload: serde_json::Value = serde_json::from_slice(&payload).ok()?;
79
80        // Extract issuer (iss) claim
81        payload
82            .get("iss")
83            .and_then(|v| v.as_str())
84            .map(|s| s.to_string())
85    }
86
87    /// Try to authenticate a token using configured providers
88    /// If issuer-based routing is configured, routes to the appropriate provider first
89    /// Returns (UserContext, provider_index)
90    async fn authenticate_token(&self, token: &AuthToken) -> AuthResult<(UserContext, usize)> {
91        // Try issuer-based routing first if configured
92        if let Some(ref issuer_map) = *self.issuer_to_provider {
93            if let Some(issuer) = self.extract_issuer_from_token(&token.token) {
94                if let Some(target_provider_type) = issuer_map.get(&issuer) {
95                    // Find the provider that matches the target type
96                    for (index, provider) in self.providers.iter().enumerate() {
97                        if provider.provider_type() == *target_provider_type {
98                            tracing::debug!(
99                                "Routing to provider {:?} based on issuer: {}",
100                                target_provider_type,
101                                issuer
102                            );
103                            let user = provider.authenticate(token).await?;
104                            return Ok((user, index));
105                        }
106                    }
107                    tracing::warn!(
108                        "Provider {:?} not found for issuer: {}",
109                        target_provider_type,
110                        issuer
111                    );
112                }
113            }
114        }
115
116        // Fallback: try all providers using can_handle
117        for (index, provider) in self.providers.iter().enumerate() {
118            if provider.can_handle(token).await {
119                let user = provider.authenticate(token).await?;
120                return Ok((user, index));
121            }
122        }
123
124        Err(AuthError::ProviderError(
125            "No provider could handle the token".to_string(),
126        ))
127    }
128
129    /// Get a provider by index (used for authorization checks)
130    pub fn get_provider(&self, index: usize) -> Option<&dyn AuthProvider> {
131        self.providers.get(index).map(|p| p.as_ref())
132    }
133}
134
135/// Tower Layer for authentication middleware
136#[derive(Clone)]
137pub struct AuthLayer {
138    state: AuthState,
139}
140
141impl AuthLayer {
142    /// Create a new auth layer with providers
143    pub fn new(providers: Vec<Box<dyn AuthProvider>>) -> Self {
144        Self {
145            state: AuthState::new(providers),
146        }
147    }
148
149    /// Set whether authentication is required
150    pub fn required(mut self, required: bool) -> Self {
151        self.state = self.state.required(required);
152        self
153    }
154
155    /// Set issuer-to-provider mapping for routing based on JWT issuer (iss) claim
156    ///
157    /// # Example
158    /// ```rust,ignore
159    /// let config = AuthConfig::default();
160    /// let tenant_id = config.azure_tenant_id.as_ref().unwrap();
161    ///
162    /// let auth_layer = AuthLayer::new(providers)
163    ///     .with_issuer_mapping(HashMap::from([
164    ///         ("https://clerk.example.com".to_string(), AuthProvider::Clerk),
165    ///         (format!("https://login.microsoftonline.com/{}/v2.0", tenant_id), AuthProvider::Msal),
166    ///         (format!("https://sts.windows.net/{}/", tenant_id), AuthProvider::Msal),
167    ///     ]));
168    /// ```
169    pub fn with_issuer_mapping(mut self, mapping: HashMap<String, AuthProviderType>) -> Self {
170        self.state = self.state.with_issuer_mapping(mapping);
171        self
172    }
173
174    /// Automatically build issuer mapping from the configured providers
175    /// This uses each provider's `expected_issuers()` method to build the mapping
176    ///
177    /// # Example
178    /// ```rust,ignore
179    /// let auth_layer = AuthLayer::new(providers)
180    ///     .with_auto_issuer_mapping(); // Automatically maps based on provider config
181    /// ```
182    pub fn with_auto_issuer_mapping(mut self) -> Self {
183        let mut mapping = HashMap::new();
184
185        for provider in self.state.providers.iter() {
186            let provider_type = provider.provider_type();
187            for issuer in provider.expected_issuers() {
188                mapping.insert(issuer, provider_type);
189            }
190        }
191
192        if !mapping.is_empty() {
193            self.state = self.state.with_issuer_mapping(mapping);
194        }
195
196        self
197    }
198}
199
200impl<S> tower_layer::Layer<S> for AuthLayer {
201    type Service = AuthMiddleware<S>;
202
203    fn layer(&self, inner: S) -> Self::Service {
204        AuthMiddleware {
205            inner,
206            state: self.state.clone(),
207        }
208    }
209}
210
211/// Middleware service
212#[derive(Clone)]
213pub struct AuthMiddleware<S> {
214    inner: S,
215    state: AuthState,
216}
217
218impl<S> tower::Service<Request> for AuthMiddleware<S>
219where
220    S: tower::Service<Request, Response = Response> + Clone + Send + 'static,
221    S::Future: Send + 'static,
222{
223    type Response = S::Response;
224    type Error = S::Error;
225    type Future = std::pin::Pin<
226        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
227    >;
228
229    fn poll_ready(
230        &mut self,
231        cx: &mut std::task::Context<'_>,
232    ) -> std::task::Poll<Result<(), Self::Error>> {
233        self.inner.poll_ready(cx)
234    }
235
236    fn call(&mut self, mut req: Request) -> Self::Future {
237        let clone = self.inner.clone();
238        let mut inner = std::mem::replace(&mut self.inner, clone);
239        let state = self.state.clone();
240
241        Box::pin(async move {
242            // Try to extract token from Authorization header
243            let auth_header = req
244                .headers()
245                .get(axum::http::header::AUTHORIZATION)
246                .and_then(|h| h.to_str().ok());
247
248            let auth_result = if let Some(header_value) = auth_header {
249                if let Some(token) = AuthToken::from_auth_header(header_value) {
250                    match state.authenticate_token(&token).await {
251                        Ok((user, provider_index)) => {
252                            // Security: Log successful authentication
253                            tracing::info!(
254                                user_id = %user.user_id,
255                                provider = %user.provider,
256                                "Authentication successful"
257                            );
258                            Some((user, provider_index))
259                        }
260                        Err(e) => {
261                            // Security: Log failed authentication attempt
262                            tracing::warn!(
263                                error = %e,
264                                provider = ?e.provider(),
265                                "Authentication failed"
266                            );
267                            if state.required {
268                                return Ok(StatusCode::UNAUTHORIZED.into_response());
269                            }
270                            None
271                        }
272                    }
273                } else {
274                    // Security: Log malformed auth header
275                    tracing::warn!("Malformed Authorization header");
276                    if state.required {
277                        return Ok(StatusCode::UNAUTHORIZED.into_response());
278                    }
279                    None
280                }
281            } else {
282                // Security: Log missing auth when required
283                if state.required {
284                    tracing::warn!("Missing Authorization header for protected endpoint");
285                    return Ok(StatusCode::UNAUTHORIZED.into_response());
286                }
287                None
288            };
289
290            // Insert user context into request extensions if authenticated
291            if let Some((user, provider_index)) = auth_result {
292                req.extensions_mut().insert(AuthExtension {
293                    user,
294                    provider_index,
295                });
296                req.extensions_mut().insert(state.clone());
297            }
298
299            inner.call(req).await
300        })
301    }
302}
303
304/// Extractor for UserContext from request
305impl<S> FromRequestParts<S> for AuthExtension
306where
307    S: Send + Sync,
308{
309    type Rejection = AuthRejection;
310
311    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
312        parts
313            .extensions
314            .get::<AuthExtension>()
315            .cloned()
316            .ok_or(AuthRejection::MissingAuth)
317    }
318}
319
320/// Extractor for optional UserContext
321pub struct OptionalAuth(pub Option<UserContext>);
322
323impl<S> FromRequestParts<S> for OptionalAuth
324where
325    S: Send + Sync,
326{
327    type Rejection = std::convert::Infallible;
328
329    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
330        Ok(OptionalAuth(
331            parts
332                .extensions
333                .get::<AuthExtension>()
334                .map(|ext| ext.user.clone()),
335        ))
336    }
337}
338
339/// Extractor for AuthContext with provider access
340/// This allows handlers to perform provider-specific authorization checks
341pub struct AuthContext {
342    pub user: UserContext,
343    state: AuthState,
344    provider_index: usize,
345}
346
347impl AuthContext {
348    /// Get the user context
349    pub fn user(&self) -> &UserContext {
350        &self.user
351    }
352
353    /// Check if user has a specific permission (provider-specific)
354    pub async fn check_permission(&self, permission: &str) -> AuthResult<bool> {
355        if let Some(provider) = self.state.get_provider(self.provider_index) {
356            provider.check_permission(&self.user, permission).await
357        } else {
358            Ok(false)
359        }
360    }
361
362    /// Check if user has a specific role (provider-specific)
363    pub async fn check_role(&self, role: &str) -> AuthResult<bool> {
364        if let Some(provider) = self.state.get_provider(self.provider_index) {
365            provider.check_role(&self.user, role).await
366        } else {
367            Ok(false)
368        }
369    }
370
371    /// Get all user roles (provider-specific)
372    pub async fn get_roles(&self) -> AuthResult<Vec<String>> {
373        if let Some(provider) = self.state.get_provider(self.provider_index) {
374            provider.get_user_roles(&self.user).await
375        } else {
376            Ok(vec![])
377        }
378    }
379
380    /// Check if user belongs to a specific organization (provider-specific)
381    pub async fn check_organization(&self, org_id: &str) -> AuthResult<bool> {
382        if let Some(provider) = self.state.get_provider(self.provider_index) {
383            provider.check_organization(&self.user, org_id).await
384        } else {
385            Ok(false)
386        }
387    }
388
389    /// Require a specific permission or return 403 Forbidden
390    pub async fn require_permission(&self, permission: &str) -> Result<(), StatusCode> {
391        if self.check_permission(permission).await.unwrap_or(false) {
392            Ok(())
393        } else {
394            Err(StatusCode::FORBIDDEN)
395        }
396    }
397
398    /// Require a specific role or return 403 Forbidden
399    pub async fn require_role(&self, role: &str) -> Result<(), StatusCode> {
400        if self.check_role(role).await.unwrap_or(false) {
401            Ok(())
402        } else {
403            Err(StatusCode::FORBIDDEN)
404        }
405    }
406}
407
408impl<S> FromRequestParts<S> for AuthContext
409where
410    S: Send + Sync,
411{
412    type Rejection = AuthRejection;
413
414    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
415        let auth_ext = parts
416            .extensions
417            .get::<AuthExtension>()
418            .ok_or(AuthRejection::MissingAuth)?;
419
420        let state = parts
421            .extensions
422            .get::<AuthState>()
423            .ok_or(AuthRejection::MissingAuth)?;
424
425        Ok(AuthContext {
426            user: auth_ext.user.clone(),
427            state: state.clone(),
428            provider_index: auth_ext.provider_index,
429        })
430    }
431}
432
433/// Rejection type for auth extraction
434#[derive(Debug)]
435pub enum AuthRejection {
436    MissingAuth,
437}
438
439impl IntoResponse for AuthRejection {
440    fn into_response(self) -> Response {
441        let (status, message) = match self {
442            AuthRejection::MissingAuth => (StatusCode::UNAUTHORIZED, "Authentication required"),
443        };
444
445        (status, message).into_response()
446    }
447}
448
// Example usage documentation module
#[doc = r##"
# Example Usage

Using the authentication middleware in an Axum router:

```rust,ignore
use axum::{extract::Path, http::StatusCode, routing::get, Router};
use libauth_rs::middleware::{AuthContext, AuthExtension, AuthLayer, OptionalAuth};
use libauth_rs::providers::{AuthConfig, ClerkProvider, MsalProvider};

#[tokio::main]
async fn main() {
    let config = AuthConfig::default();

    // Create multiple providers
    let clerk = ClerkProvider::new(&config).await.unwrap();
    let msal = MsalProvider::new(&config).await.unwrap();

    // Create auth layer with automatic issuer-based routing
    let auth_layer = AuthLayer::new(vec![Box::new(clerk), Box::new(msal)])
        .with_auto_issuer_mapping() // Automatically route based on JWT issuer
        .required(false);

    let app = Router::new()
        .route("/public", get(public_handler))
        .route("/protected", get(protected_handler))
        .route("/admin", get(admin_handler))
        .layer(auth_layer);

    // Run your server...
}

// Public handler - no auth required
async fn public_handler(OptionalAuth(user): OptionalAuth) -> String {
    match user {
        Some(u) => format!("Hello, {}!", u.user_id),
        None => "Hello, anonymous!".to_string(),
    }
}

// Protected handler - auth required, basic user info
async fn protected_handler(AuthExtension { user, .. }: AuthExtension) -> String {
    format!("Welcome, {}!", user.user_id)
}

// Admin handler - auth required + role check using AuthContext
async fn admin_handler(auth: AuthContext) -> Result<String, StatusCode> {
    // Use provider-specific authorization
    auth.require_role("admin").await?;

    let roles = auth.get_roles().await.unwrap_or_default();
    Ok(format!("Admin access granted! Your roles: {:?}", roles))
}

// Per-provider routers example
async fn setup_per_provider_routers() -> Router {
    let config = AuthConfig::default();
    let clerk = ClerkProvider::new(&config).await.unwrap();
    let msal = MsalProvider::new(&config).await.unwrap();

    // Clerk-specific routes (axum 0.8 path parameters are written `{name}`, not `:name`)
    let clerk_router = Router::new()
        .route("/org/{org_id}/members", get(clerk_org_members))
        .layer(AuthLayer::new(vec![Box::new(clerk)]).required(true));

    // MSAL-specific routes
    let msal_router = Router::new()
        .route("/azure/groups", get(msal_groups))
        .layer(AuthLayer::new(vec![Box::new(msal)]).required(true));

    // Combine routers
    Router::new()
        .nest("/clerk", clerk_router)
        .nest("/msal", msal_router)
}

// Clerk handler - requires membership in the requested organization
async fn clerk_org_members(
    Path(org_id): Path<String>,
    auth: AuthContext,
) -> Result<String, StatusCode> {
    if !auth.check_organization(&org_id).await.unwrap_or(false) {
        return Err(StatusCode::FORBIDDEN);
    }
    Ok(format!("Members of organization {}", org_id))
}

// MSAL handler - reports the roles returned by the authenticating provider
async fn msal_groups(auth: AuthContext) -> Result<String, StatusCode> {
    let roles = auth
        .get_roles()
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    Ok(format!("Your roles: {:?}", roles))
}
```
"##]
pub(crate) mod _example {}